Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48467

The current API's forward method only accepted a Tensor or a Tuple of
Tensors; this makes it more generic by accepting any Sequence of Tensors.

ghstack-source-id: 118436340
Test Plan: waitforbuildbot
Reviewed By: rohan-varma
Differential Revision: D25181944
fbshipit-source-id: 4db251dad52c01abc69f3d327788f2e4289e6c9d
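A minimal sketch of what this enables (illustrative only; `check` and `Batch` are the helpers defined in the file below):

    import torch

    tensors = [torch.randn(4, 8), torch.randn(4, 8)]  # a list, not a tuple
    check(tensors)          # passes: any Sequence of Tensors is accepted
    batch = Batch(tensors)  # batch.atomic is False; len(batch) == 2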
187 lines · 5.3 KiB · Python
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Manipulation of micro-batches."""
import typing
from typing import Callable, Iterable, Iterator, List, Sequence, Union, cast

import torch
from torch import Tensor
import torch.cuda.comm

__all__: List[str] = []

Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]
Function = Callable[[TensorOrTensors], TensorOrTensors]

class Batch:
    """An abstraction of an atomic tensor or a sequence of tensors. This
    removes the boilerplate otherwise needed to distinguish an atomic tensor
    from a sequence of tensors.

    ::

        x = generate_tensor_or_tensors()
        x = Batch(x)

        # in-place update
        x[0] = F.apply(x[0])
        x[:] = F.apply(*x)

        # f(x) if x is a tensor.
        # f(*x) if x is a sequence of tensors.
        # y is also a batch.
        y = x.call(f)

    """

    def __init__(self, value: TensorOrTensors) -> None:
        self.value = value
        self.atomic = torch.is_tensor(value)

    @property
    def tensor(self) -> Tensor:
        """Retrieves the underlying tensor."""
        if not self.atomic:
            raise AttributeError("not atomic batch")
        return cast(Tensor, self.value)

    @property
    def tensors(self) -> Tensors:
        """Retrieves the underlying tensors."""
        if self.atomic:
            raise AttributeError("batch is atomic")
        return cast(Tensors, self.value)

    @property
    def tensor_or_tensors(self) -> TensorOrTensors:
        """Retrieves the underlying tensor or tensors regardless of type."""
        return self.value

    def call(self, function: Function) -> "Batch":
        """Calls a function with the underlying tensor or tensors and wraps
        the output with :class:`Batch`.
        """
        return Batch(function(self.value))

    def __repr__(self) -> str:
        return f"Batch[atomic={self.atomic!r}]({self.value!r})"

    def __iter__(self) -> Iterator[Tensor]:
        if self.atomic:
            yield self.tensor
        else:
            yield from self.tensors

    def __len__(self) -> int:
        return 1 if self.atomic else len(self.tensors)

    def __getitem__(self, index: int) -> Tensor:
        if not self.atomic:
            return self.tensors[index]

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        return self.tensor

    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
    @typing.overload
    def __setitem__(self, index: int, value: Tensor) -> None:
        ...

    @typing.overload
    def __setitem__(self, index: slice, value: Tensors) -> None:
        ...

    def __setitem__(self, index: Union[int, slice], value: TensorOrTensors) -> None:
        if isinstance(index, int):
            value = cast(Tensor, value)
            self._setitem_by_index(index, value)
        else:
            value = cast(Tensors, value)
            self._setitem_by_slice(index, value)

    def _setitem_by_index(self, index: int, value: Tensor) -> None:
        if not self.atomic:
            i = index
            self.value = self.value[:i] + (value,) + self.value[i + 1 :]  # type: ignore
            return

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        self.value = value

    def _setitem_by_slice(self, index: slice, value: Tensors) -> None:
        if not (index.start is index.stop is index.step is None):
            raise NotImplementedError("only slice [:] supported")

        if not self.atomic:
            self.value = value
            return

        if len(value) != 1:
            raise IndexError("atomic batch cannot be replaced with multiple tensors")

        self.value = value[0]
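
    # A minimal sketch of the __setitem__ semantics above (illustrative, not
    # part of the original module):
    #
    #     b = Batch((torch.zeros(2), torch.ones(2)))
    #     b[0] = torch.full((2,), 5.0)  # replaces only the first tensor
    #     b[:] = (torch.zeros(2),)      # replaces the whole sequence
    #     b[1:] = ...                   # NotImplementedError: only [:] works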

def check(input: TensorOrTensors) -> None:
    """Checks whether the input is a tensor or a sequence of tensors.

    Raises:
        TypeError: input is not a tensor or a sequence of tensors.

    """
    if isinstance(input, Sequence):
        for x in input:
            if not isinstance(x, Tensor):
                raise TypeError(f"expected Tensor, but got {x.__class__.__name__}")
        return

    if not isinstance(input, Tensor):
        raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")
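
# A minimal usage sketch of `check` (illustrative, not part of the original
# module): with the Sequence generalization, lists pass as well as tuples.
#
#     check(torch.zeros(2))                    # atomic tensor: passes
#     check((torch.zeros(2), torch.ones(2)))   # tuple of tensors: passes
#     check([torch.zeros(2), torch.ones(2)])   # list of tensors: passes
#     check([torch.zeros(2), "oops"])          # raises TypeError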

def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
    """Splits an input mini-batch into multiple micro-batches."""
    inputs: Iterable[TensorOrTensors]

    if isinstance(input, Tensor):
        inputs = input.chunk(chunks)
    else:
        rotated: List[Tensors] = []

        for tensor in input:
            tensors = tensor.chunk(chunks)
            rotated.append(cast(Tensors, tensors))

        inputs = zip(*rotated)

    return [Batch(x) for x in inputs]
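
# A minimal sketch of `scatter` (illustrative, not part of the original
# module): a mini-batch of 8 samples becomes 4 micro-batches of 2 each.
#
#     batches = scatter(torch.randn(8, 10), chunks=4)
#     assert len(batches) == 4
#     assert batches[0].tensor.shape == (2, 10)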

def gather(outputs: List[Batch]) -> TensorOrTensors:
    """Concatenates output micro-batches into a mini-batch."""
    output: TensorOrTensors

    if outputs[0].atomic:
        tensors = tuple(b.tensor for b in outputs)
        output = torch.cat(tensors)
    else:
        rotated = [b.tensors for b in outputs]
        output_buf = []

        for tensors in zip(*rotated):
            output_buf.append(torch.cat(tensors))

        output = tuple(output_buf)

    return output
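
# A minimal round-trip sketch (illustrative, not part of the original module):
# `gather` undoes `scatter` along the batch dimension.
#
#     x = torch.randn(8, 10)
#     assert torch.equal(gather(scatter(x, chunks=4)), x)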