diff --git a/torch/autograd/anomaly_mode.py b/torch/autograd/anomaly_mode.py index 66c657dce8a..6a33f1780ef 100644 --- a/torch/autograd/anomaly_mode.py +++ b/torch/autograd/anomaly_mode.py @@ -76,9 +76,8 @@ class detect_anomaly(object): def __enter__(self) -> None: torch.set_anomaly_enabled(True) - def __exit__(self, *args: Any) -> bool: + def __exit__(self, *args: Any) -> None: torch.set_anomaly_enabled(self.prev) - return False class set_detect_anomaly(object): @@ -103,6 +102,5 @@ class set_detect_anomaly(object): def __enter__(self) -> None: pass - def __exit__(self, *args: Any) -> bool: + def __exit__(self, *args: Any) -> None: torch.set_anomaly_enabled(self.prev) - return False diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index ec58569c5ca..57492925c54 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -68,7 +68,6 @@ class no_grad(_DecoratorContextManager): def __exit__(self, *args): torch.set_grad_enabled(self.prev) - return False class enable_grad(_DecoratorContextManager): @@ -108,7 +107,6 @@ class enable_grad(_DecoratorContextManager): def __exit__(self, *args): torch.set_grad_enabled(self.prev) - return False class set_grad_enabled(object): @@ -157,4 +155,3 @@ class set_grad_enabled(object): def __exit__(self, *args): torch.set_grad_enabled(self.prev) - return False diff --git a/torch/autograd/grad_mode.pyi b/torch/autograd/grad_mode.pyi index ebe81398182..f04283b943e 100644 --- a/torch/autograd/grad_mode.pyi +++ b/torch/autograd/grad_mode.pyi @@ -7,15 +7,15 @@ T = TypeVar('T', bound=FuncType) class no_grad: def __enter__(self) -> None: ... - def __exit__(self, *args: Any) -> bool: ... + def __exit__(self, *args: Any) -> None: ... def __call__(self, func: T) -> T: ... class enable_grad: def __enter__(self) -> None: ... - def __exit__(self, *args: Any) -> bool: ... + def __exit__(self, *args: Any) -> None: ... def __call__(self, func: T) -> T: ... class set_grad_enabled: def __init__(self, mode: bool) -> None: ... def __enter__(self) -> None: ... - def __exit__(self, *args: Any) -> bool: ... + def __exit__(self, *args: Any) -> None: ...