Return None instead of False, and change the return type from bool to None in the type stub (#39324)

Summary:
# What's this

Just a small bug fix related to typing stubs.
I haven't opened an issue. I will do so if I must, but this PR is very small (only a 6-line diff).

## What I encountered

pytorch 1.5.0 with mypy 0.770 behaves odd. The code is following:
```python
import torch

def f() -> int:  # Mypy says: `error: Missing return statement`
    with torch.no_grad():
        return 1
```

No mypy error is expected, but actually mypy 0.770 warns about `Missing return statement`.

## This is because

`mypy >= 0.730` with `--warn-unreachable` says it's unreachable because `torch.no_grad()` may "swallow" the exception raised in the return statement.
http://mypy-lang.blogspot.com/2019/09/mypy-730-released.html

Here is a small "swallowing" example:

```python
from typing import Generator
from contextlib import contextmanager

@contextmanager
def swallow_zerodiv() -> Generator[None, None, None]:
    try:
        yield None
    except ZeroDivisionError:
        pass
    finally:
        pass

def div(a: int, b: int) -> float:  # This function seems to be `(int, int) -> float` but is actually `(int, int) -> Optional[float]`, because `return a / b` may be swallowed
    with swallow_zerodiv():
        return a / b

if __name__ == '__main__':
    result = div(1, 0)
    print(result, type(result))  # None <class 'NoneType'>
```

To suppress this behavior, one can tell mypy not to swallow any exceptions by returning `Literal[False]` or `None` from the `__exit__` method of the context manager.

# What I did

Return `None` instead of `bool` to tell mypy that "I never swallow your exception".
I chose `None` because I cannot use `Literal[False]` without typing_extensions in Python <= 3.7.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39324

Differential Revision: D21833651

Pulled By: albanD

fbshipit-source-id: d5cad2e5e19068bd68dc773e997bf13f7e60f4de
This commit is contained in:
Keigo Kawamura 2020-06-02 10:43:55 -07:00 committed by Facebook GitHub Bot
parent bb0377bb24
commit b5cd3a80bb
3 changed files with 5 additions and 10 deletions

View File

@ -76,9 +76,8 @@ class detect_anomaly(object):
def __enter__(self) -> None:
torch.set_anomaly_enabled(True)
def __exit__(self, *args: Any) -> bool:
def __exit__(self, *args: Any) -> None:
torch.set_anomaly_enabled(self.prev)
return False
class set_detect_anomaly(object):
@ -103,6 +102,5 @@ class set_detect_anomaly(object):
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
def __exit__(self, *args: Any) -> None:
torch.set_anomaly_enabled(self.prev)
return False

View File

@ -68,7 +68,6 @@ class no_grad(_DecoratorContextManager):
def __exit__(self, *args):
torch.set_grad_enabled(self.prev)
return False
class enable_grad(_DecoratorContextManager):
@ -108,7 +107,6 @@ class enable_grad(_DecoratorContextManager):
def __exit__(self, *args):
torch.set_grad_enabled(self.prev)
return False
class set_grad_enabled(object):
@ -157,4 +155,3 @@ class set_grad_enabled(object):
def __exit__(self, *args):
torch.set_grad_enabled(self.prev)
return False

View File

@ -7,15 +7,15 @@ T = TypeVar('T', bound=FuncType)
class no_grad:
def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> bool: ...
def __exit__(self, *args: Any) -> None: ...
def __call__(self, func: T) -> T: ...
class enable_grad:
def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> bool: ...
def __exit__(self, *args: Any) -> None: ...
def __call__(self, func: T) -> T: ...
class set_grad_enabled:
def __init__(self, mode: bool) -> None: ...
def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> bool: ...
def __exit__(self, *args: Any) -> None: ...