Update on "[PyTorch] Make TORCH_CHECK less likely to interfere with inlining"

Now it is smaller and calls an out-of-line function in
case of failure.

Differential Revision: [D25481308](https://our.internmc.facebook.com/intern/diff/D25481308/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25481308/)!

[ghstack-poisoned]
Scott Wolchok 2020-12-14 11:38:21 -08:00
commit 818cf218fc
9 changed files with 84 additions and 21 deletions

View File

@@ -258,6 +258,31 @@ def test_exception_early_stop_asap(setup_rpc):
     assert counter == 2
 
 
+def test_nested_input(setup_rpc):
+    class NestedInput(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc_a = nn.Linear(1, 1)
+            self.fc_b = nn.Linear(1, 1)
+
+        def forward(self, inp):
+            return inp
+
+    model = nn.Sequential(NestedInput())
+    model = Pipe(model, chunks=2)
+
+    a = torch.rand(10, 1, requires_grad=True)
+    b = torch.rand(10, 1, requires_grad=True)
+
+    # TypeError: expected Tensor, but got tuple
+    with pytest.raises(TypeError):
+        model((a, (a, b))).local_value()
+
+    # TypeError: expected Tensor, but got list
+    with pytest.raises(TypeError):
+        model((a, [a, b])).local_value()
+
+
 def test_input_pair(setup_rpc):
     class Two(nn.Module):
         def __init__(self):
@@ -282,6 +307,17 @@ def test_input_pair(setup_rpc):
     assert a.grad is not None
     assert b.grad is not None
 
+    # Test with list.
+    a.grad = None
+    b.grad = None
+
+    a_out, b_out = model([a, b]).local_value()
+    loss = (a_out + b_out).mean()
+    loss.backward()
+
+    assert a.grad is not None
+    assert b.grad is not None
+
 
 def test_input_singleton(setup_rpc):
     class One(nn.Module):
@@ -305,6 +341,18 @@ def test_input_singleton(setup_rpc):
     assert all(p.grad is not None for p in model.parameters())
     assert a.grad is not None
 
+    # Test with list
+    a.grad = None
+    for p in model.parameters():
+        p.grad = None
+
+    (a_out,) = model([a]).local_value()
+    loss = a_out.mean()
+    loss.backward()
+
+    assert all(p.grad is not None for p in model.parameters())
+    assert a.grad is not None
+
 
 def test_input_varargs(setup_rpc):
     model = nn.Sequential(nn.Linear(1, 1))
@@ -336,7 +384,7 @@ def test_non_tensor(setup_rpc):
         model("hello")
 
 
-def test_non_tensor_tuple(setup_rpc):
+def test_non_tensor_sequence(setup_rpc):
     class NonTensorTuple(nn.Module):
         def forward(self, x):
             return (x, "hello")
@@ -353,6 +401,10 @@ def test_non_tensor_tuple(setup_rpc):
     with pytest.raises(TypeError):
         model((x, "hello"))
 
+    # TypeError: expected Tensor to scatter, but got str
+    with pytest.raises(TypeError):
+        model([x, "hello"])
+
 
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
 def test_deferred_batch_norm(checkpoint, setup_rpc):

View File

@@ -18,7 +18,7 @@ Usage::
     pipe = Pipe(model, balance, chunks=8)
 
 """
-from typing import List, Tuple, Union
+from typing import List, Union, Sequence
 
 import torch
 from torch import Tensor

@@ -32,7 +32,7 @@ __all__ = ["balance_by_time", "balance_by_size"]
 
 Device = Union[torch.device, int, str]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
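
This `Tensors` alias swap from `Tuple[Tensor, ...]` to `Sequence[Tensor]` repeats across every file below. As a minimal sketch of what the looser alias buys callers (the function names here are illustrative, not from this diff): `Sequence[Tensor]` matches lists as well as tuples, so list-valued micro-batches now type-check.

from typing import Sequence, Tuple, Union

import torch
from torch import Tensor

OldTensors = Tuple[Tensor, ...]  # old alias: only tuples of tensors
NewTensors = Sequence[Tensor]    # new alias: tuples, lists, or any sequence

def f_old(batch: Union[Tensor, OldTensors]) -> None: ...
def f_new(batch: Union[Tensor, NewTensors]) -> None: ...

t = torch.zeros(1)
f_old((t, t))  # accepted under both aliases
f_new([t, t])  # a list only satisfies the new Sequence-based alias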

View File

@@ -7,7 +7,7 @@
 """Per-layer profilers."""
 import copy
 import time
-from typing import Generator, List, Tuple, Union
+from typing import Generator, List, Union, Sequence
 
 import torch
 from torch import Tensor

@@ -20,7 +20,7 @@ __all__: List[str] = []
 
 Device = Union[torch.device, int, str]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]

View File

@@ -27,7 +27,16 @@ copied entirely.
 from collections import deque
 from contextlib import contextmanager
 import threading
-from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Deque,
+    Generator,
+    List,
+    Optional,
+    Union,
+    Sequence,
+    Tuple
+)
 
 import torch
 from torch import Tensor

@@ -40,7 +49,7 @@ from .phony import get_phony
 
 __all__ = ["is_checkpointing", "is_recomputing"]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 # Types for shared memory between Checkpoint and Recompute.

View File

@@ -8,7 +8,7 @@
 and computation on the same GPU.
 """
 from collections import deque
-from typing import Deque, List, Optional, Tuple
+from typing import Deque, List, Optional, Tuple, Sequence
 
 import torch
 from torch import Tensor

@@ -18,7 +18,7 @@ from .stream import AbstractStream, current_stream, get_device, record_stream, u
 
 __all__: List[str] = []
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 
 # Common interface between :class:`Copy` and :class:`Wait`.

View File

@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 """Manipulation of micro-batches."""
 import typing
-from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast
+from typing import Callable, Iterable, Iterator, List, Union, cast, Sequence
 
 import torch
 from torch import Tensor

@@ -15,7 +15,7 @@ import torch.cuda.comm
 
 __all__: List[str] = []
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 Function = Callable[[TensorOrTensors], TensorOrTensors]

@@ -110,7 +110,7 @@ class Batch:
     def _setitem_by_index(self, index: int, value: Tensor) -> None:
         if not self.atomic:
             i = index
-            self.value = self.value[:i] + (value,) + self.value[i + 1 :]
+            self.value = self.value[:i] + (value,) + self.value[i + 1 :]  # type: ignore
             return
 
         if index != 0:
@@ -139,9 +139,10 @@ def check(input: TensorOrTensors) -> None:
         TypeError: input is not a tensor or tensors.
 
     """
-    if isinstance(input, tuple):
+    if isinstance(input, Sequence):
         for x in input:
-            check(x)
+            if not isinstance(x, Tensor):
+                raise TypeError(f"expected Tensor, but got {x.__class__.__name__}")
         return
 
     if not isinstance(input, Tensor):
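
For illustration, here is the updated `check` mirrored as a standalone sketch (an assumed standalone form, not part of the diff): any sequence of tensors passes, while a non-tensor element is reported by its own type instead of recursing as the old code did.

from typing import Sequence

import torch
from torch import Tensor

def check(input) -> None:
    # Accept any sequence, but require every element to be a Tensor;
    # nested tuples/lists are rejected instead of recursed into.
    if isinstance(input, Sequence):
        for x in input:
            if not isinstance(x, Tensor):
                raise TypeError(f"expected Tensor, but got {x.__class__.__name__}")
        return
    if not isinstance(input, Tensor):
        raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")

a = torch.rand(10, 1)
check(a)       # ok: a single tensor
check((a, a))  # ok: a tuple of tensors
check([a, a])  # ok: a list of tensors now passes too
try:
    check((a, [a, a]))  # a nested list is not a Tensor element
except TypeError as e:
    print(e)  # expected Tensor, but got list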

View File

@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 """The Pipe interface."""
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast, Sequence
 
 import torch
 from torch import Tensor, nn

@@ -27,7 +27,7 @@ __all__ = ["Pipe"]
 
 Device = Union[torch.device, int, str]
 Devices = Union[Iterable[Device], List[Device]]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 if TYPE_CHECKING:

@@ -310,11 +310,11 @@ class Pipe(Module):
         """:class:`Pipe` is a fairly transparent module wrapper. It doesn't
         modify the input and output signature of the underlying module. But
         there's type restriction. Input and output have to be a
-        :class:`~torch.Tensor` or a tuple of tensors. This restriction is
+        :class:`~torch.Tensor` or a sequence of tensors. This restriction is
         applied at partition boundaries too.
 
         Args:
-            input (torch.Tensor or Tuple[torch.Tensor, ...]): input mini-batch
+            input (torch.Tensor or Sequence[torch.Tensor]): input mini-batch
 
         Returns:
             :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch
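
Tying the docstring change to usage, a minimal sketch of calling `forward` with a list input (the single-process RPC setup and module shape mirror the tests above; the `Pipe` import path is an assumption):

import torch
from torch import nn
import torch.distributed.rpc as rpc
from torch.distributed.pipeline.sync import Pipe  # assumed import path

# forward returns an RRef, so RPC must be initialized even for one process.
rpc.init_rpc("worker0", rank=0, world_size=1)

class One(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)

    def forward(self, only):
        return (self.fc(only),)

model = Pipe(nn.Sequential(One()), chunks=2)
a = torch.rand(10, 1, requires_grad=True)

# A list is now accepted anywhere a tuple of tensors was.
(a_out,) = model([a]).local_value()
a_out.mean().backward()

rpc.shutdown()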

View File

@@ -7,7 +7,7 @@
 """The pipeline parallelism of Pipe."""
 from queue import Queue
 from types import TracebackType
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence
 
 import torch
 from torch import Tensor, nn

@@ -25,7 +25,7 @@ from .worker import Task, create_workers, join_workers
 
 __all__: List[str] = []
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]

View File

@@ -17,6 +17,7 @@ from typing import (
     List,
     Optional,
     Set,
+    Sequence,
     Tuple,
     Type,
     TypeVar,

@@ -33,7 +34,7 @@ from .tracker import current_skip_tracker
 
 __all__ = ["skippable", "stash", "pop", "verify_skippables"]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 StashPop = Union["stash", "pop"]