mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
# What's this
Just a small bug fix related to typing stubs.
I haven't opened an issue. I will do so if I must, but this PR is very small (only a 6-line diff).
## What I encountered
pytorch 1.5.0 with mypy 0.770 behaves oddly. The code is as follows:
```python
import torch
def f() -> int: # Mypy says: `error: Missing return statement`
with torch.no_grad():
return 1
```
No mypy error is expected, but actually mypy 0.770 warns about `Missing return statement`.
## This is because
`mypy >= 0.730` with `--warn-unreachable` says it's unreachable because `torch.no_grad()` may "swallow" an error raised in the return statement.
http://mypy-lang.blogspot.com/2019/09/mypy-730-released.html
Here is a small "swallowing" example:
```python
from typing import Generator
from contextlib import contextmanager
@contextmanager
def swallow_zerodiv() -> Generator[None, None, None]:
try:
yield None
except ZeroDivisionError:
pass
finally:
pass
def div(a: int, b: int) -> float: # This function seems `(int, int) -> float` but actually `(int, int) -> Optional[float]` because ` return a / b` may be swallowed
with swallow_zerodiv():
return a / b
if __name__ == '__main__':
result = div(1, 0)
print(result, type(result)) # None <class 'NoneType'>
```
To suppress this behavior, one can tell mypy that the context manager never swallows exceptions by annotating its `__exit__` method as returning `Literal[False]` or `None`.
# What I did
Return `None` instead of `bool` to tell mypy that "I never swallow your exception".
I chose `None` because `Literal[False]` is not available without `typing_extensions` on Python <= 3.7.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39324
Differential Revision: D21833651
Pulled By: albanD
fbshipit-source-id: d5cad2e5e19068bd68dc773e997bf13f7e60f4de
107 lines
4.2 KiB
Python
107 lines
4.2 KiB
Python
import torch
|
|
import warnings
|
|
|
|
from typing import Any
|
|
|
|
class detect_anomaly(object):
    r"""Context-manager that enables anomaly detection for the autograd engine.

    This does two things:

    - Running the forward pass with detection enabled will allow the backward
      pass to print the traceback of the forward operation that created the
      failing backward function.
    - Any backward computation that generates a "nan" value will raise an error.

    .. warning::
        This mode should be enabled only for debugging as the different tests
        will slow down your program execution.

    Example:

        >>> import torch
        >>> from torch import autograd
        >>> class MyFunc(autograd.Function):
        ...     @staticmethod
        ...     def forward(ctx, inp):
        ...         return inp.clone()
        ...     @staticmethod
        ...     def backward(ctx, gO):
        ...         # Error during the backward pass
        ...         raise RuntimeError("Some error in backward")
        ...         return gO.clone()
        >>> def run_fn(a):
        ...     out = MyFunc.apply(a)
        ...     return out.sum()
        >>> inp = torch.rand(10, 10, requires_grad=True)
        >>> out = run_fn(inp)
        >>> out.backward()
            Traceback (most recent call last):
              File "<stdin>", line 1, in <module>
              File "/your/pytorch/install/torch/tensor.py", line 93, in backward
                torch.autograd.backward(self, gradient, retain_graph, create_graph)
              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
                allow_unreachable=True)  # allow_unreachable flag
              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
                return self._forward_cls.backward(self, *args)
              File "<stdin>", line 8, in backward
            RuntimeError: Some error in backward
        >>> with autograd.detect_anomaly():
        ...     inp = torch.rand(10, 10, requires_grad=True)
        ...     out = run_fn(inp)
        ...     out.backward()
            Traceback of forward call that caused the error:
              File "tmp.py", line 53, in <module>
                out = run_fn(inp)
              File "tmp.py", line 44, in run_fn
                out = MyFunc.apply(a)
            Traceback (most recent call last):
              File "<stdin>", line 4, in <module>
              File "/your/pytorch/install/torch/tensor.py", line 93, in backward
                torch.autograd.backward(self, gradient, retain_graph, create_graph)
              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
                allow_unreachable=True)  # allow_unreachable flag
              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
                return self._forward_cls.backward(self, *args)
              File "<stdin>", line 8, in backward
            RuntimeError: Some error in backward

    """

    def __init__(self) -> None:
        """Record the current anomaly-detection state and warn about the cost.

        The state is captured here (at construction time), and is what
        ``__exit__`` later restores.
        """
        self.prev = torch.is_anomaly_enabled()
        # stacklevel=2 attributes the warning to the code that instantiated
        # the context manager rather than to this constructor.
        warnings.warn('Anomaly Detection has been enabled. '
                      'This mode will increase the runtime '
                      'and should only be enabled for debugging.', stacklevel=2)

    def __enter__(self) -> None:
        """Turn anomaly detection on for the duration of the ``with`` block."""
        torch.set_anomaly_enabled(True)

    def __exit__(self, *args: Any) -> None:
        """Restore the anomaly-detection state recorded in ``__init__``.

        Annotated as returning ``None`` (not ``bool``) so that type checkers
        know this context manager never swallows exceptions.
        """
        torch.set_anomaly_enabled(self.prev)
|
|
|
|
|
|
class set_detect_anomaly(object):
    r"""Context-manager that sets the anomaly detection for the autograd engine on or off.

    ``set_detect_anomaly`` will enable or disable the autograd anomaly detection
    based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    See ``detect_anomaly`` above for details of the anomaly detection behaviour.

    Arguments:
        mode (bool): Flag whether to enable anomaly detection (``True``),
                     or disable (``False``).

    """

    def __init__(self, mode: bool) -> None:
        """Record the current state, then immediately apply ``mode``.

        The mode is applied here rather than in ``__enter__`` so that a plain
        call ``set_detect_anomaly(mode)`` (function-style use) takes effect
        without a ``with`` statement.
        """
        self.prev = torch.is_anomaly_enabled()
        torch.set_anomaly_enabled(mode)

    def __enter__(self) -> None:
        """Do nothing; the requested mode was already applied in ``__init__``."""
        pass

    def __exit__(self, *args: Any) -> None:
        """Restore the anomaly-detection state recorded in ``__init__``.

        Annotated as returning ``None`` (not ``bool``) so that type checkers
        know this context manager never swallows exceptions.
        """
        torch.set_anomaly_enabled(self.prev)
|