Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Update on "[PyTorch] Make TORCH_CHECK less likely to interfere with inlining"
Now the check is smaller and calls an out-of-line function in the failure case.

Differential Revision: [D25481308](https://our.internmc.facebook.com/intern/diff/D25481308/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25481308/)!

[ghstack-poisoned]
This commit is contained in: commit 818cf218fc
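For context, a minimal C++ sketch of the pattern the commit summary describes: keep the macro's expansion tiny and move all failure handling into a separate, non-inlined function so that callers using the check stay cheap to inline. The names below (checkFailSketch, CHECK_SKETCH) are hypothetical illustrations, not the actual c10/TORCH_CHECK implementation.

    #include <stdexcept>
    #include <string>

    // Cold, out-of-line failure path: marked noinline so none of the string
    // formatting or throwing code gets pulled into the caller's body.
    [[noreturn]] [[gnu::noinline]] void checkFailSketch(
        const char* file, int line, const std::string& msg) {
      throw std::runtime_error(
          std::string(file) + ":" + std::to_string(line) + ": check failed: " + msg);
    }

    // The macro itself expands to a single unlikely branch plus one call, so
    // functions that use it stay small enough for the compiler to inline.
    // (__builtin_expect is the GCC/Clang branch hint.)
    #define CHECK_SKETCH(cond, msg)                          \
      do {                                                   \
        if (__builtin_expect(!(cond), 0)) {                  \
          checkFailSketch(__FILE__, __LINE__, (msg));        \
        }                                                    \
      } while (false)

    int checked_div(int a, int b) {
      CHECK_SKETCH(b != 0, "divisor must be non-zero");
      return a / b;  // hot path: no exception machinery inlined here
    }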
@@ -258,6 +258,31 @@ def test_exception_early_stop_asap(setup_rpc):
     assert counter == 2
 
 
+def test_nested_input(setup_rpc):
+    class NestedInput(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc_a = nn.Linear(1, 1)
+            self.fc_b = nn.Linear(1, 1)
+
+        def forward(self, inp):
+            return inp
+
+    model = nn.Sequential(NestedInput())
+    model = Pipe(model, chunks=2)
+
+    a = torch.rand(10, 1, requires_grad=True)
+    b = torch.rand(10, 1, requires_grad=True)
+
+    # TypeError: expected Tensor, but got tuple
+    with pytest.raises(TypeError):
+        model((a, (a, b))).local_value()
+
+    # TypeError: expected Tensor, but got list
+    with pytest.raises(TypeError):
+        model((a, [a, b])).local_value()
+
+
 def test_input_pair(setup_rpc):
     class Two(nn.Module):
         def __init__(self):
@@ -282,6 +307,17 @@ def test_input_pair(setup_rpc):
     assert a.grad is not None
     assert b.grad is not None
 
+    # Test with list.
+    a.grad = None
+    b.grad = None
+    a_out, b_out = model([a, b]).local_value()
+    loss = (a_out + b_out).mean()
+    loss.backward()
+
+    assert a.grad is not None
+    assert b.grad is not None
+
+
 
 def test_input_singleton(setup_rpc):
     class One(nn.Module):
@@ -305,6 +341,18 @@ def test_input_singleton(setup_rpc):
     assert all(p.grad is not None for p in model.parameters())
     assert a.grad is not None
 
+    # Test with list
+    a.grad = None
+    for p in model.parameters():
+        p.grad = None
+
+    (a_out,) = model([a]).local_value()
+    loss = a_out.mean()
+    loss.backward()
+
+    assert all(p.grad is not None for p in model.parameters())
+    assert a.grad is not None
+
 
 def test_input_varargs(setup_rpc):
     model = nn.Sequential(nn.Linear(1, 1))
@@ -336,7 +384,7 @@ def test_non_tensor(setup_rpc):
         model("hello")
 
 
-def test_non_tensor_tuple(setup_rpc):
+def test_non_tensor_sequence(setup_rpc):
     class NonTensorTuple(nn.Module):
         def forward(self, x):
             return (x, "hello")
@@ -353,6 +401,10 @@ def test_non_tensor_tuple(setup_rpc):
     with pytest.raises(TypeError):
         model((x, "hello"))
 
+    # TypeError: expected Tensor to scatter, but got str
+    with pytest.raises(TypeError):
+        model([x, "hello"])
+
 
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
 def test_deferred_batch_norm(checkpoint, setup_rpc):
@@ -18,7 +18,7 @@ Usage::
     pipe = Pipe(model, balance, chunks=8)
 
 """
-from typing import List, Tuple, Union
+from typing import List, Union, Sequence
 
 import torch
 from torch import Tensor
@@ -32,7 +32,7 @@ __all__ = ["balance_by_time", "balance_by_size"]
 
 Device = Union[torch.device, int, str]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 
@@ -7,7 +7,7 @@
 """Per-layer profilers."""
 import copy
 import time
-from typing import Generator, List, Tuple, Union
+from typing import Generator, List, Union, Sequence
 
 import torch
 from torch import Tensor
@@ -20,7 +20,7 @@ __all__: List[str] = []
 
 Device = Union[torch.device, int, str]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 
@@ -27,7 +27,16 @@ copied entirely.
 from collections import deque
 from contextlib import contextmanager
 import threading
-from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Deque,
+    Generator,
+    List,
+    Optional,
+    Union,
+    Sequence,
+    Tuple
+)
 
 import torch
 from torch import Tensor
@@ -40,7 +49,7 @@ from .phony import get_phony
 __all__ = ["is_checkpointing", "is_recomputing"]
 
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 # Types for shared memory between Checkpoint and Recompute.
@@ -8,7 +8,7 @@
 and computation on the same GPU.
 """
 from collections import deque
-from typing import Deque, List, Optional, Tuple
+from typing import Deque, List, Optional, Tuple, Sequence
 
 import torch
 from torch import Tensor
@@ -18,7 +18,7 @@ from .stream import AbstractStream, current_stream, get_device, record_stream, u
 __all__: List[str] = []
 
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 
 
 # Common interface between :class:`Copy` and :class:`Wait`.
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 """Manipulation of micro-batches."""
 import typing
-from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast
+from typing import Callable, Iterable, Iterator, List, Union, cast, Sequence
 
 import torch
 from torch import Tensor
@@ -15,7 +15,7 @@ import torch.cuda.comm
 __all__: List[str] = []
 
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 Function = Callable[[TensorOrTensors], TensorOrTensors]
 
@@ -110,7 +110,7 @@ class Batch:
     def _setitem_by_index(self, index: int, value: Tensor) -> None:
         if not self.atomic:
             i = index
-            self.value = self.value[:i] + (value,) + self.value[i + 1 :]
+            self.value = self.value[:i] + (value,) + self.value[i + 1 :]  # type: ignore
             return
 
         if index != 0:
@@ -139,9 +139,10 @@ def check(input: TensorOrTensors) -> None:
         TypeError: input is not a tensor or tensors.
 
     """
-    if isinstance(input, tuple):
+    if isinstance(input, Sequence):
         for x in input:
-            check(x)
+            if not isinstance(x, Tensor):
+                raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")
         return
 
     if not isinstance(input, Tensor):
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 """The Pipe interface."""
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast, Sequence
 
 import torch
 from torch import Tensor, nn
@@ -27,7 +27,7 @@ __all__ = ["Pipe"]
 Device = Union[torch.device, int, str]
 Devices = Union[Iterable[Device], List[Device]]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 if TYPE_CHECKING:
@@ -310,11 +310,11 @@ class Pipe(Module):
         """:class:`Pipe` is a fairly transparent module wrapper. It doesn't
         modify the input and output signature of the underlying module. But
         there's type restriction. Input and output have to be a
-        :class:`~torch.Tensor` or a tuple of tensors. This restriction is
+        :class:`~torch.Tensor` or a sequence of tensors. This restriction is
         applied at partition boundaries too.
 
         Args:
-            input (torch.Tensor or Tuple[torch.Tensor, ...]): input mini-batch
+            input (torch.Tensor or Sequence[torch.Tensor]): input mini-batch
 
         Returns:
             :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch
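As an illustration of the behavior the updated docstring above describes, a rough usage sketch follows. The import path, single-process RPC setup, and shapes are assumptions and may differ for the exact version in this diff; the tests added earlier in this commit exercise the tuple/list input cases directly.

    # Rough usage sketch (assumed import path and single-process RPC setup).
    import os

    import torch
    from torch import nn
    from torch.distributed import rpc
    from torch.distributed.pipeline.sync import Pipe  # path assumed; may differ by version

    # Pipe returns RRefs, so an RPC agent must be initialized even for one process.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    rpc.init_rpc("worker", rank=0, world_size=1)

    model = Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), chunks=2)

    a = torch.rand(10, 1, requires_grad=True)

    # Input may be a single Tensor or, with Tensors = Sequence[Tensor], a tuple
    # or list of Tensors; the output is an RRef to the result.
    out = model(a).local_value()
    out.mean().backward()

    rpc.shutdown()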
@@ -7,7 +7,7 @@
 """The pipeline parallelism of Pipe."""
 from queue import Queue
 from types import TracebackType
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence
 
 import torch
 from torch import Tensor, nn
@@ -25,7 +25,7 @@ from .worker import Task, create_workers, join_workers
 __all__: List[str] = []
 
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
@@ -17,6 +17,7 @@ from typing import (
     List,
     Optional,
     Set,
+    Sequence,
     Tuple,
     Type,
     TypeVar,
@@ -33,7 +34,7 @@ from .tracker import current_skip_tracker
 __all__ = ["skippable", "stash", "pop", "verify_skippables"]
 
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 StashPop = Union["stash", "pop"]