This PR implements the feature described in #107036 for `no_grad`, `enable_grad` and `inference_mode`.
Users can still use these as before, but they can also use them without parentheses.
For example:
```python
import torch

a = torch.ones(1, requires_grad=True)

def do_something():
    print(2 * a)

with torch.no_grad():
    do_something()  # tensor([2.])

torch.no_grad()(do_something)()  # tensor([2.])
torch.no_grad(do_something)()  # tensor([2.])
do_something()  # tensor([2.], grad_fn=<MulBackward0>)
```
For `inference_mode`, decorating without parentheses is equivalent to decorating with the default `mode=True`, similar to how dataclasses behave (https://docs.python.org/3/library/dataclasses.html#module-contents).
Closes #107036
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107086
Approved by: https://github.com/albanD
Fixes #104985
Implemented a `set_multithreading_enabled` C++ function that directly alters the state, rather than using the `MultithreadingEnabled` RAII class, which automatically reset the state when the object was destroyed. This more closely aligns with `set_grad_enabled`, which works as expected, and allows the Python class `set_multithreading_enabled` to act as both a function and a context manager.
I also added a getter: `torch._C.is_multithreading_enabled`
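A minimal sketch of the resulting usage (API names as stated above):
```python
import torch

# As a context manager: the previous state is restored on exit.
with torch.autograd.set_multithreading_enabled(False):
    assert not torch._C.is_multithreading_enabled()

# As a plain function: the state change persists.
torch.autograd.set_multithreading_enabled(True)
assert torch._C.is_multithreading_enabled()
```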
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105291
Approved by: https://github.com/albanD
There are some that I can't easily switch, due to reasons like:
- Dynamo modelling the guard
- BC concerns (for torch.autograd.set_multithreading_enabled)
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102642
Approved by: https://github.com/albanD
tl;dr: this should fix some minor perf regressions that were caused by adding more `as_strided()` calls in AOT Autograd.
This PR adds a new context manager, `torch.autograd._set_view_replay_enabled()`.
Context: AOT Autograd has special handling for "outputs that alias graph intermediates". E.g. given this function:
```python
def f(x):
    y = torch.mul(x, 2)
    out = y.view(-1)
    return out
```
AOT Autograd will do the following:
```python
def fn_to_compile(x):
    y = torch.mul(x, 2)
    out = y.view(-1)
    # return the graph intermediate
    return y, out

compiled_fn = compile(fn_to_compile)

def wrapper(x):
    y, out = compiled_fn(x)
    # regenerate the alias of the graph intermediate
    return out._view_func(y)
```
What's annoying is that `out._view_func()` will result in an `as_strided()` call, because `out` is an ordinary runtime tensor. This (likely?) caused a perf regression, because when running the backward, our `as_strided_backward()` is slower than our `view_backward()`.
In this PR, I added some TLS for instructing autograd to do view replay instead of as_strided, even when given a normal tensor. I'm definitely interested in thoughts from autograd folks (cc @albanD @soulitzer). A few points that I want to bring up:
(1) One reason that this API seems generally useful to me is because of the case where you `torch.compile()` a function, and you pass in two inputs that alias each other, and mutate one of the inputs. Autograd is forced to add a bunch of as_strided() calls into the graph when this happens, but this would give users an escape hatch for better compiled perf in this situation
(2) To be fair, AOT Autograd probably won't need this TLS in the long term. There's a better (more complicated) solution, where AOT Autograd manually precomputes the view chain off of graph intermediates during tracing, and re-applies them at runtime. This is kind of complicated though and feels lower priority to implement immediately.
(3) Given all of that I made the API private, but lmk what you all think.
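For concreteness, a minimal usage sketch of the private context manager named above, assuming it takes the enable flag as its argument:
```python
import torch

x = torch.ones(4, requires_grad=True)

# Inside this block, autograd regenerates view aliases by replaying
# the original view op instead of falling back to as_strided().
with torch.autograd._set_view_replay_enabled(True):
    y = torch.mul(x, 2)
    out = y.view(-1)
    out.sum().backward()
```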
This is a followup of https://github.com/pytorch/pytorch/pull/92255.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92588
Approved by: https://github.com/ezyang, https://github.com/albanD
Copied the type hints from the other context managers.
Not sure how to add type hints for `clone`, since it returns the same class. The `Self` type wasn't introduced until Python 3.11, and mypy only recently added support for it. We could also use `"inference_mode"` in quotes to avoid referencing the class before it's declared, or `from __future__ import annotations` to allow its use without quotes. Or we could just skip it.
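A runnable toy sketch of the two workarounds (using a stand-in class, not the real `torch.inference_mode`):
```python
from __future__ import annotations  # enables unquoted forward references

class inference_mode:
    """Toy stand-in, only to illustrate the annotation options for clone."""

    def __init__(self, mode: bool = True) -> None:
        self.mode = mode

    # With the __future__ import, the bare class name works as a return
    # annotation; without it, quote it: def clone(self) -> "inference_mode":
    def clone(self) -> inference_mode:
        return self.__class__(self.mode)
```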
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94223
Approved by: https://github.com/albanD
This is a new version of #15648 based on the latest master branch.
Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.
In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
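For reference, a sketch of what the manual skip markup looks like, on a hypothetical function's doctest:
```python
def add(a, b):
    """
    Example:
        >>> # xdoctest: +SKIP
        >>> add(1, 2)
        3
    """
    return a + b
```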
Fixes https://github.com/pytorch/pytorch/issues/71105
@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
Summary:
This fixes the case when `torch.inference_mode` is called with `mode=False` (disabled). When used as a decorator, it ignored the argument and enabled inference mode anyway.
`_DecoratorContextManager` is changed so that each new instance is a copy of the current one (preserving its arguments) instead of a fresh instance with default parameters.
I also added more tests to cover this case.
Current behaviour:
```python
>>> import torch
>>> x = torch.ones(1, 2, 3, requires_grad=True)
>>> @torch.inference_mode(mode=False)
... def func(x):
...     return x * x
...
>>> out = func(x)
>>> out.requires_grad
False
```
New behaviour (fixed):
```python
>>> import torch
>>> x = torch.ones(1, 2, 3, requires_grad=True)
>>> @torch.inference_mode(mode=False)
... def func(x):
...     return x * x
...
>>> out = func(x)
>>> out.requires_grad
True
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68617
Reviewed By: mrshenli
Differential Revision: D32958434
Pulled By: albanD
fbshipit-source-id: 133c69970ef8bffb9fc9ab5142dedcffc4c32945
Summary:
Adds a note to the autograd notes explaining the difference between several often-conflated mechanisms.
Also adds a link to this note from the docs in `grad_mode` and `nn.module`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58513
Reviewed By: gchanan
Differential Revision: D28651129
Pulled By: soulitzer
fbshipit-source-id: af9eb1749b641fc1b632815634eea36bf7979156
Summary:
Fixes https://github.com/pytorch/pytorch/issues/56608
- Adds binding to the `c10::InferenceMode` RAII class in `torch._C._autograd.InferenceMode` through pybind. Also binds the `torch.is_inference_mode` function.
- Adds context manager `torch.inference_mode` to manage an instance of `c10::InferenceMode` (global). Implemented in `torch/autograd/grad_mode.py` to reuse the `_DecoratorContextManager` class.
- Adds some tests based on those linked in the issue + several more for just the context manager
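A small usage example of the context manager (the kind of example the todo list below asks for):
```python
import torch

x = torch.ones(1, 2, 3, requires_grad=True)

with torch.inference_mode():
    y = x * x  # no autograd graph is recorded inside the block

print(y.requires_grad)  # False
```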
Issues/todos (not necessarily for this PR):
- Improve short inference mode description
- Small example
- Improved testing since there is no direct way of checking TLS/dispatch keys
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58045
Reviewed By: agolynski
Differential Revision: D28390595
Pulled By: soulitzer
fbshipit-source-id: ae98fa036c6a2cf7f56e0fd4c352ff804904752c
Summary:
Change `self` to `self.__class__()` in `_DecoratorContextManager` to ensure a new object is created every time a function is called recursively.
Fixes https://github.com/pytorch/pytorch/issues/44531
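A minimal sketch of the recursive-decorator pattern this fixes (hypothetical `scale` helper; with the fix, each call constructs a fresh context manager):
```python
import torch

@torch.no_grad()
def scale(t, n):
    # Each recursive call re-enters the decorated wrapper; previously
    # the shared context manager instance broke on re-entry.
    if n == 0:
        return t
    return 2 * scale(t, n - 1)

x = torch.ones(1, requires_grad=True)
print(scale(x, 3))  # tensor([8.]) with no grad_fn
```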
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44633
Reviewed By: agolynski
Differential Revision: D23783601
Pulled By: albanD
fbshipit-source-id: a818664dee7bdb061a40ede27ef99e9546fc80bb
Summary:
- Add `torch._C` bindings from `torch/csrc/autograd/init.cpp`
- Renamed `torch._C.set_grad_enabled` to `torch._C._set_grad_enabled` so it doesn't conflict with `torch.set_grad_enabled` anymore
This is a continuation of gh-38201. All I did was resolve merge conflicts and finish the annotation of `_DecoratorContextManager.__call__` that ezyang started in the first commit.
~Reverts commit b5cd3a80bb, which was only motivated by not having `typing_extensions` available.~ (JIT can't be made to understand `Literal[False]`, so keep as is).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43415
Reviewed By: ngimel
Differential Revision: D23301168
Pulled By: malfet
fbshipit-source-id: cb5290f2e556b4036592655b9fe54564cbb036f6
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41371
**Summary**
This commit enables the use of `torch.no_grad()` in a with item of a
with statement within JIT. Note that the use of this context manager as
a decorator is not supported.
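A minimal sketch of the newly supported pattern, assuming a scripted function:
```python
import torch

@torch.jit.script
def scripted_fn(x: torch.Tensor) -> torch.Tensor:
    # torch.no_grad() can now appear as a with item in TorchScript.
    with torch.no_grad():
        y = x * 2
    return y
```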
**Test Plan**
This commit adds a test case to the existing with statements tests for
`torch.no_grad()`.
**Fixes**
This commit fixes #40259.
Test Plan: Imported from OSS
Reviewed By: gmagogsfm
Differential Revision: D22649519
Pulled By: SplitInfinity
fbshipit-source-id: 7fa675d04835377666dfd0ca4e6bc393dc541ab9
Summary:
# What's this
Just a small bug fix related to typing stubs.
I haven't opened an issue. I will do so if needed, but this PR is very small (only a 6-line diff).
## What I encountered
pytorch 1.5.0 with mypy 0.770 behaves oddly. The code is the following:
```python
import torch

def f() -> int:  # mypy says: `error: Missing return statement`
    with torch.no_grad():
        return 1
```
No mypy error is expected, but actually mypy 0.770 warns about `Missing return statement`.
## This is because
`mypy >= 0.730` with `--warn-unreachable` reports this because `torch.no_grad()` may "swallow" an exception around the `return` statement, so the end of the function is considered reachable.
http://mypy-lang.blogspot.com/2019/09/mypy-730-released.html
Here is a small "swallowing" example:
```python
from contextlib import contextmanager
from typing import Generator

@contextmanager
def swallow_zerodiv() -> Generator[None, None, None]:
    try:
        yield None
    except ZeroDivisionError:
        pass
    finally:
        pass

# This function looks like `(int, int) -> float` but is actually
# `(int, int) -> Optional[float]`, because `return a / b` may be swallowed.
def div(a: int, b: int) -> float:
    with swallow_zerodiv():
        return a / b

if __name__ == '__main__':
    result = div(1, 0)
    print(result, type(result))  # None <class 'NoneType'>
```
To suppress this behavior, one can tell mypy that the context manager never swallows exceptions, by returning `Literal[False]` or `None` from the `__exit__` method of the context manager.
# What I did
Return `None` instead of `bool` to tell mypy that "I never swallow your exception".
I chose `None` because `Literal[False]` is not available without typing_extensions on Python <= 3.7.
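A simplified sketch of the resulting annotations (a stand-in class, not the real implementation):
```python
from types import TracebackType
from typing import Optional

class no_grad_like:
    def __enter__(self) -> None:
        ...

    # Returning None (rather than bool) tells mypy that no exception
    # is ever swallowed by this context manager.
    def __exit__(
        self,
        exc_type: Optional[type],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        ...
```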
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39324
Differential Revision: D21833651
Pulled By: albanD
fbshipit-source-id: d5cad2e5e19068bd68dc773e997bf13f7e60f4de
Summary:
Closes https://github.com/pytorch/pytorch/issues/31497
This allows `torch.no_grad` and `torch.enable_grad` to be used as decorators for generator functions, in which case grad is disabled/enabled only inside the body of the generator and the surrounding mode is restored outside of it.
https://github.com/pytorch/pytorch/issues/31497 doesn't include a complete reproducer, but the included test with `torch.is_grad_enabled` shows this working where it failed before.
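A minimal sketch of the pattern this enables:
```python
import torch

@torch.no_grad()
def gen():
    # Grad is disabled only while the generator body is executing.
    yield torch.is_grad_enabled()

print(next(gen()))              # False
print(torch.is_grad_enabled())  # True: restored outside the generator
```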
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31792
Differential Revision: D19274971
Pulled By: albanD
fbshipit-source-id: fde6d3fd95d76c8d324ad02db577213a4b68ccbe
Summary:
- Fix broken sparse_coo_examples, update output
- Change `Tensor(...)` to `tensor(...)`
- Fix arguments to math.log to be floats
While the last might be debatable, mypy currently complains when passing an int to `math.log`. As it is not essential for our examples, let's be clean w.r.t. other people's expectations.
These popped up while checking examples in the context of #12500.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12707
Differential Revision: D10415256
Pulled By: SsnL
fbshipit-source-id: c907b576b02cb0f89d8f261173dbf4b3175b4b8d