Updated docs and added deprecation warnings to acknowledge a bool tensor (#22261)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22261
ghimport-source-id: 1611d62d056a04c0ad15ef662e594a3d206a78e2

Test Plan: Imported from OSS

Differential Revision: D16005990

Pulled By: izdeby

fbshipit-source-id: 2413824aa75a0755719e4df11acd21e6607e5a85
This commit is contained in:
Iurii Zdebskyi 2019-08-05 07:37:51 -07:00 committed by Facebook Github Bot
parent 520982d1df
commit 19c675178f
11 changed files with 199 additions and 167 deletions

View File

@ -15,6 +15,10 @@ static std::vector<Tensor> expandTensors(const Tensor & self, TensorList indices
std::vector<Tensor> result;
for (const auto & index : indices) {
if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
if (index.scalar_type() == kByte) {
AT_WARN("indexing with dtype torch.uint8 is now deprecated," \
" please use a dtype torch.bool instead.");
}
// The sizes of the ByteTensor mask or bool tensor must match the sizes of the
// corresponding dimensions in self
for (int64_t j = 0; j < index.dim(); j++) {
@ -122,4 +126,3 @@ struct AdvancedIndex {
}}

View File

@ -10,6 +10,8 @@ Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, Scalar value) {
// As we dispatch on self and TH is type-checked, we need different definitions.
// This can be fixed by moving to ATen.
if (mask.dtype() == at::ScalarType::Byte) {
AT_WARN("masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
"please use a mask with dtype torch.bool instead.");
return legacy::cpu::_th_masked_fill_(self, mask, value);
} else {
return legacy::cpu::_th_masked_fill_bool_(self, mask, value);
@ -20,6 +22,8 @@ Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, const Tensor & valu
// As we dispatch on self and TH is type-checked, we need different definitions.
// This can be fixed by moving to ATen.
if (mask.dtype() == at::ScalarType::Byte) {
AT_WARN("masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
"please use a mask with dtype torch.bool instead.");
return legacy::cpu::_th_masked_fill_(self, mask, value);
} else {
return legacy::cpu::_th_masked_fill_bool_(self, mask, value);
@ -30,6 +34,8 @@ Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & s
// As we dispatch on self and TH is type-checked, we need different definitions.
// This can be fixed by moving to ATen.
if (mask.dtype() == at::ScalarType::Byte) {
AT_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
"please use a mask with dtype torch.bool instead.");
return legacy::cpu::_th_masked_scatter_(self, mask, source);
} else {
return legacy::cpu::_th_masked_scatter_bool_(self, mask, source);
@ -38,6 +44,8 @@ Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & s
Tensor masked_select_cpu(const Tensor & self, const Tensor & mask) {
if (mask.dtype() == at::ScalarType::Byte) {
AT_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \
"please use a mask with dtype torch.bool instead.");
return legacy::cpu::_th_masked_select(self, mask);
} else {
return legacy::cpu::_th_masked_select_bool(self, mask);

View File

@ -10,6 +10,8 @@ Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, Scalar value) {
// As we dispatch on self and TH is type-checked, we need different definitions.
// This can be fixed by moving to ATen.
if (mask.dtype() == at::ScalarType::Byte) {
AT_WARN("masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
"please use a mask with dtype torch.bool instead.");
return legacy::cuda::_th_masked_fill_(self, mask, value);
} else {
return legacy::cuda::_th_masked_fill_bool_(self, mask, value);
@ -20,6 +22,8 @@ Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Tensor & val
// As we dispatch on self and TH is type-checked, we need different definitions.
// This can be fixed by moving to ATen.
if (mask.dtype() == at::ScalarType::Byte) {
AT_WARN("masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
"please use a mask with dtype torch.bool instead.");
return legacy::cuda::_th_masked_fill_(self, mask, value);
} else {
return legacy::cuda::_th_masked_fill_bool_(self, mask, value);
@ -30,6 +34,8 @@ Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor &
// As we dispatch on self and TH is type-checked, we need different definitions.
// This can be fixed by moving to ATen.
if (mask.dtype() == at::ScalarType::Byte) {
AT_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
"please use a mask with dtype torch.bool instead.");
return legacy::cuda::_th_masked_scatter_(self, mask, source);
} else {
return legacy::cuda::_th_masked_scatter_bool_(self, mask, source);
@ -38,6 +44,8 @@ Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor &
Tensor masked_select_cuda(const Tensor & self, const Tensor & mask) {
if (mask.dtype() == at::ScalarType::Byte) {
AT_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \
"please use a mask with dtype torch.bool instead.");
return legacy::cuda::_th_masked_select(self, mask);
} else {
return legacy::cuda::_th_masked_select_bool(self, mask);

View File

@ -469,9 +469,9 @@ view of a storage and defines numeric operations on it.
.. automethod:: where
.. automethod:: zero_
.. class:: ByteTensor()
.. class:: BoolTensor()
The following methods are unique to :class:`torch.ByteTensor`.
The following methods are unique to :class:`torch.BoolTensor`.
.. automethod:: all
.. automethod:: any

View File

@ -2,7 +2,7 @@ from common_utils import TestCase, run_tests
import torch
from torch import tensor
import unittest
import warnings
class TestIndexing(TestCase):
def test_single_int(self):
@ -44,9 +44,11 @@ class TestIndexing(TestCase):
v = torch.tensor([True, False, True], dtype=torch.bool)
boolIndices = torch.tensor([True, False, False], dtype=torch.bool)
uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8)
self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
self.assertEqual(v[boolIndices], v[uint8Indices])
self.assertEqual(v[boolIndices], tensor([True], dtype=torch.bool))
with warnings.catch_warnings(record=True) as w:
self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
self.assertEqual(v[boolIndices], v[uint8Indices])
self.assertEqual(v[boolIndices], tensor([True], dtype=torch.bool))
self.assertEquals(len(w), 2)
def test_bool_indices_accumulate(self):
mask = torch.zeros(size=(10, ), dtype=torch.bool)
@ -64,8 +66,10 @@ class TestIndexing(TestCase):
def test_byte_mask(self):
v = torch.randn(5, 7, 3)
mask = torch.ByteTensor([1, 0, 1, 1, 0])
self.assertEqual(v[mask].shape, (3, 7, 3))
self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
with warnings.catch_warnings(record=True) as w:
self.assertEqual(v[mask].shape, (3, 7, 3))
self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
self.assertEquals(len(w), 2)
v = torch.tensor([1.])
self.assertEqual(v[v == 0], torch.tensor([]))
@ -73,15 +77,19 @@ class TestIndexing(TestCase):
def test_byte_mask_accumulate(self):
mask = torch.zeros(size=(10, ), dtype=torch.uint8)
y = torch.ones(size=(10, 10))
y.index_put_((mask, ), y[mask], accumulate=True)
self.assertEqual(y, torch.ones(size=(10, 10)))
with warnings.catch_warnings(record=True) as w:
y.index_put_((mask, ), y[mask], accumulate=True)
self.assertEqual(y, torch.ones(size=(10, 10)))
self.assertEquals(len(w), 2)
def test_multiple_byte_mask(self):
v = torch.randn(5, 7, 3)
# note: these broadcast together and are transposed to the first dim
mask1 = torch.ByteTensor([1, 0, 1, 1, 0])
mask2 = torch.ByteTensor([1, 1, 1])
self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
with warnings.catch_warnings(record=True) as w:
self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
self.assertEquals(len(w), 2)
def test_byte_mask2d(self):
v = torch.randn(5, 7, 3)
@ -121,7 +129,7 @@ class TestIndexing(TestCase):
y[idx] = -1
self.assertEqual(x, y)
mask = torch.zeros(4, 3).byte()
mask = torch.zeros(4, 3).bool()
y[mask] = -1
self.assertEqual(x, y)
@ -299,7 +307,11 @@ class TestIndexing(TestCase):
x = torch.arange(0., 16).view(4, 4)
b = torch.ByteTensor([True, False, True, False])
value = torch.tensor([3., 4., 5., 6.])
x[b] = value
with warnings.catch_warnings(record=True) as w:
x[b] = value
self.assertEquals(len(w), 1)
self.assertEqual(x[0], value)
self.assertEqual(x[1], torch.arange(4, 8))
self.assertEqual(x[2], value)
@ -378,7 +390,6 @@ class TestIndexing(TestCase):
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
class NumpyTests(TestCase):
def test_index_no_floats(self):
a = torch.tensor([[[5.]]])
@ -484,10 +495,11 @@ class NumpyTests(TestCase):
index = tensor([False] * 6)
self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])
index = torch.ByteTensor(4, 4).zero_()
self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])
self.assertRaisesRegex(IndexError, 'mask', lambda: arr[(slice(None), index)])
with warnings.catch_warnings(record=True) as w:
index = torch.ByteTensor(4, 4).zero_()
self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])
self.assertRaisesRegex(IndexError, 'mask', lambda: arr[(slice(None), index)])
self.assertEquals(len(w), 2)
def test_boolean_indexing_onedim(self):
# Indexing a 2-dimensional array with
@ -606,6 +618,5 @@ class NumpyTests(TestCase):
expected = b.double().unsqueeze(1).expand(100, 100)
self.assertEqual(a, expected)
if __name__ == '__main__':
run_tests()

View File

@ -8285,28 +8285,30 @@ class _TestTorchMixin(object):
[True, False, True, False, True]], device=device))
def test_masked_scatter(self):
for dtype in [torch.uint8, torch.bool]:
num_copy, num_dest = 3, 10
dest = torch.randn(num_dest)
src = torch.randn(num_copy)
mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=dtype)
dest2 = dest.clone()
dest.masked_scatter_(mask, src)
j = 0
for i in range(num_dest):
if mask[i]:
dest2[i] = src[j]
j += 1
self.assertEqual(dest, dest2, 0)
# make source bigger than number of 1s in mask
src = torch.randn(num_dest)
dest.masked_scatter_(mask, src)
# make src smaller. this should fail
src = torch.randn(num_copy - 1)
with self.assertRaises(RuntimeError):
with warnings.catch_warnings(record=True) as w:
for dtype in [torch.uint8, torch.bool]:
num_copy, num_dest = 3, 10
dest = torch.randn(num_dest)
src = torch.randn(num_copy)
mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=dtype)
dest2 = dest.clone()
dest.masked_scatter_(mask, src)
j = 0
for i in range(num_dest):
if mask[i]:
dest2[i] = src[j]
j += 1
self.assertEqual(dest, dest2, 0)
# make source bigger than number of 1s in mask
src = torch.randn(num_dest)
dest.masked_scatter_(mask, src)
# make src smaller. this should fail
src = torch.randn(num_copy - 1)
with self.assertRaises(RuntimeError):
dest.masked_scatter_(mask, src)
self.assertEqual(len(w), 3)
def test_masked_scatter_bool_tensor(self):
for device in torch.testing.get_all_device_types():
@ -8323,40 +8325,44 @@ class _TestTorchMixin(object):
def test_masked_select(self):
for device in torch.testing.get_all_device_types():
for dtype in [torch.uint8, torch.bool]:
num_src = 10
src = torch.randn(num_src, device=device)
mask = torch.rand(num_src, device=device).clamp(0, 1).mul(2).floor().to(dtype)
dst = src.masked_select(mask)
dst2 = []
for i in range(num_src):
if mask[i]:
dst2 += [src[i]]
self.assertEqual(dst, torch.tensor(dst2), 0)
with warnings.catch_warnings(record=True) as w:
for dtype in [torch.uint8, torch.bool]:
num_src = 10
src = torch.randn(num_src, device=device)
mask = torch.rand(num_src, device=device).clamp(0, 1).mul(2).floor().to(dtype)
dst = src.masked_select(mask)
dst2 = []
for i in range(num_src):
if mask[i]:
dst2 += [src[i]]
self.assertEqual(dst, torch.tensor(dst2), 0)
dst3 = torch.empty_like(src, device=device)
torch.masked_select(src, mask, out=dst3)
self.assertEqual(dst3, torch.Tensor(dst2), 0)
dst3 = torch.empty_like(src, device=device)
torch.masked_select(src, mask, out=dst3)
self.assertEqual(dst3, torch.Tensor(dst2), 0)
self.assertEquals(len(w), 1)
def test_masked_fill(self):
for dtype in [torch.uint8, torch.bool]:
num_dest = 10
dst = torch.randn(num_dest)
mask = torch.rand(num_dest).mul(2).floor().to(dtype)
val = random.random()
dst2 = dst.clone()
dst.masked_fill_(mask, val)
for i in range(num_dest):
if mask[i]:
dst2[i] = val
self.assertEqual(dst, dst2, 0)
with warnings.catch_warnings(record=True) as w:
for dtype in [torch.uint8, torch.bool]:
num_dest = 10
dst = torch.randn(num_dest)
mask = torch.rand(num_dest).mul(2).floor().to(dtype)
val = random.random()
dst2 = dst.clone()
dst.masked_fill_(mask, val)
for i in range(num_dest):
if mask[i]:
dst2[i] = val
self.assertEqual(dst, dst2, 0)
# test non-contiguous case
dst = torch.randn(num_dest, num_dest, num_dest).permute((2, 0, 1))
dst2 = dst.clone()
dst.masked_fill_((dst > 0).to(dtype), val)
dst2.masked_fill_((dst2 > 0).to(dtype), val)
self.assertEqual(dst, dst2, 0)
# test non-contiguous case
dst = torch.randn(num_dest, num_dest, num_dest).permute((2, 0, 1))
dst2 = dst.clone()
dst.masked_fill_((dst > 0).to(dtype), val)
dst2.masked_fill_((dst2 > 0).to(dtype), val)
self.assertEqual(dst, dst2, 0)
self.assertEquals(len(w), 3)
def test_masked_fill_bool_tensor(self):
for device in torch.testing.get_all_device_types():

View File

@ -39,7 +39,7 @@ read gen_pyi for the gory details.
needed_modules = set()
FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: bool=False"
FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False"
# this could be more precise w.r.t list contents etc. How to do Ellipsis?
INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"
@ -120,7 +120,7 @@ def type_to_python(typename, size=None):
'IntArrayRef[]': 'Union[_int, _size]',
'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]',
'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]',
'bool': 'bool',
'bool': '_bool',
'double': '_float',
'int64_t': '_int',
'accreal': 'Number',
@ -208,7 +208,7 @@ def sig_for_ops(opname):
tname = 'bool'
else:
tname = 'int'
if tname in {'float', 'int'}:
if tname in {'float', 'int', 'bool'}:
tname = 'builtins.' + tname
return ['def {}(self) -> {}: ...'.format(opname, tname)]
else:
@ -275,7 +275,7 @@ def generate_type_hints(fname, decls, is_tensor=False):
python_args += ["dtype: _dtype=None",
"layout: layout=strided",
"device: Union[_device, str, None]=None",
"requires_grad:bool=False"]
"requires_grad:_bool=False"]
python_args_s = ', '.join(python_args)
python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']]
@ -317,9 +317,9 @@ def generate_type_hints(fname, decls, is_tensor=False):
def gen_nn_modules(out):
def replace_forward(m):
# We instruct mypy to not emit errors for the `forward` and `__call__` declarations since mypy
# would otherwise correctly point out that Module's descendants' `forward` declarations
# conflict with `Module`s. Specificlaly, `Module` defines `forward(self, *args)` while the
# descandantes define more specific forms, such as `forward(self, input: Tensor)`, which
# would otherwise correctly point out that Module's descendants' `forward` declarations
# conflict with `Module`s. Specificlaly, `Module` defines `forward(self, *args)` while the
# descandantes define more specific forms, such as `forward(self, input: Tensor)`, which
# violates Liskov substitutability. The 'mypy' team recommended this solution for now.
forward_def = m.group(0) + " # type: ignore"
call_def = re.sub(r'def forward', 'def __call__', forward_def)
@ -412,7 +412,7 @@ def gen_pyi(declarations_path, out):
unsorted_function_hints = collections.defaultdict(list)
unsorted_function_hints.update({
'set_flush_denormal': ['def set_flush_denormal(mode: bool) -> bool: ...'],
'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
@ -428,7 +428,7 @@ def gen_pyi(declarations_path, out):
'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
' device: Union[_device, str, None]=None, requires_grad:bool=False) -> Tensor: ...'],
' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
'range': ['def range(start: Number, end: Number,'
' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
.format(FACTORY_PARAMS)],
@ -503,28 +503,28 @@ def gen_pyi(declarations_path, out):
'__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
" -> None: ...".format(INDICES)],
'tolist': ['def tolist(self) -> List: ...'],
'requires_grad_': ['def requires_grad_(self, mode: bool=True) -> Tensor: ...'],
'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
'element_size': ['def element_size(self) -> _int: ...'],
'dim': ['def dim(self) -> _int: ...'],
'ndimension': ['def ndimension(self) -> _int: ...'],
'nelement': ['def nelement(self) -> _int: ...'],
'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: bool=False) -> Tensor: ...'],
'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: _bool=False) -> Tensor: ...'],
'numpy': ['def numpy(self) -> Any: ...'],
'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
'map_': ['def map_(tensor: Tensor, callable: Callable) -> Tensor: ...'],
'storage': ['def storage(self) -> Storage: ...'],
'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: bool=False)'
'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: _bool=False)'
' -> Union[str, Tensor]: ...'],
'get_device': ['def get_device(self) -> _int: ...'],
'contiguous': ['def contiguous(self) -> Tensor: ...'],
'is_contiguous': ['def is_contiguous(self) -> bool: ...'],
'is_cuda': ['is_cuda: bool'],
'is_leaf': ['is_leaf: bool'],
'is_contiguous': ['def is_contiguous(self) -> _bool: ...'],
'is_cuda': ['is_cuda: _bool'],
'is_leaf': ['is_leaf: _bool'],
'storage_offset': ['def storage_offset(self) -> _int: ...'],
'to': ['def to(self, dtype: _dtype, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
'non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
'def to(self, other: Tensor, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
'non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
'def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
],
'item': ["def item(self) -> Number: ..."],
})
@ -541,7 +541,8 @@ def gen_pyi(declarations_path, out):
'def {}(self, value: Number,'
' other: Union[Tensor, Number]{})'
' -> Tensor: ...'.format(name, out_suffix))
simple_conversions = ['byte', 'char', 'cpu', 'double', 'float', 'half', 'int', 'long', 'short']
simple_conversions = ['byte', 'char', 'cpu', 'double', 'float',
'half', 'int', 'long', 'short', 'bool']
for name in simple_conversions:
unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))
@ -568,11 +569,11 @@ def gen_pyi(declarations_path, out):
# TODO: These are deprecated, maybe we shouldn't type hint them
legacy_class_hints = []
for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
'ShortStorage', 'CharStorage', 'ByteStorage'):
'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage'):
legacy_class_hints.append('class {}(Storage): ...'.format(c))
for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
'ShortTensor', 'CharTensor', 'ByteTensor'):
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
legacy_class_hints.append('class {}(Tensor): ...'.format(c))
# Generate type signatures for dtype classes
@ -584,7 +585,7 @@ def gen_pyi(declarations_path, out):
for n in
['float32', 'float', 'float64', 'double', 'float16', 'half',
'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
'complex32', 'complex64', 'complex128', 'quint8', 'qint8', 'qint32']]
'complex32', 'complex64', 'complex128', 'quint8', 'qint8', 'qint32', 'bool']]
# Write out the stub
# ~~~~~~~~~~~~~~~~~~

View File

@ -24,9 +24,9 @@ __all__ = [
'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
'no_grad', 'enable_grad', 'rand', 'randn',
'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
'ShortStorage', 'CharStorage', 'ByteStorage',
'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage',
'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
'ShortTensor', 'CharTensor', 'ByteTensor', 'Tensor',
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
]
################################################################################

View File

@ -44,6 +44,7 @@ per_tensor_affine: qscheme = ...
# is necessary
_int = builtins.int
_float = builtins.float
_bool = builtins.bool
class device:
type: str
@ -67,7 +68,7 @@ _qscheme = qscheme
_size = Union[Size, List[_int], Tuple[_int, ...]]
# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float]
Number = Union[builtins.int, builtins.float, builtins.bool]
class Generator:
device: _device = ...
@ -85,7 +86,7 @@ class Tensor:
dtype: _dtype = ...
shape: Size = ...
device: _device = ...
requires_grad: bool = ...
requires_grad: _bool = ...
grad: Optional[Tensor] = ...
${tensor_method_hints}
@ -93,7 +94,7 @@ class Tensor:
# Manually defined methods from torch/tensor.py
def register_hook(self, hook: Callable) -> Any: ...
def retain_grad(self) -> None: ...
def is_shared(self) -> bool: ...
def is_shared(self) -> _bool: ...
def share_memory_(self) -> None: ...
# TODO: fill in the types for these, or otherwise figure out some
# way to not have to write these out again...
@ -115,15 +116,14 @@ ${dtype_class_hints}
# Pure Python functions defined in torch/__init__.py
def typename(obj) -> str: ...
def is_tensor(obj) -> bool: ...
def is_storage(obj) -> bool: ...
def is_tensor(obj) -> _bool: ...
def is_storage(obj) -> _bool: ...
def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
def set_default_dtype(d : _dtype) -> None: ...
def manager_path() -> str: ...
def compiled_with_cxx11_abi() -> bool: ...
def compiled_with_cxx11_abi() -> _bool: ...
# The return value of this function depends on the value of `as_tuple`,
# (similar to `unique`, `lu`, etc.); as such, it is not
# possible to type correctly
def nonzero(input: Tensor, *, out: Optional[Tensor]=None, as_tuple: Optional[bool]=None): ...
def nonzero(input: Tensor, *, out: Optional[Tensor]=None, as_tuple: Optional[_bool]=None): ...

View File

@ -286,20 +286,20 @@ add_docstr_all('all',
r"""
.. function:: all() -> bool
Returns True if all elements in the tensor are non-zero, False otherwise.
Returns True if all elements in the tensor are True, False otherwise.
Example::
>>> a = torch.randn(1, 3).byte() % 2
>>> a = torch.rand(1, 2).bool()
>>> a
tensor([[1, 0, 0]], dtype=torch.uint8)
tensor([[False, True]], dtype=torch.bool)
>>> a.all()
tensor(0, dtype=torch.uint8)
tensor(False, dtype=torch.bool)
.. function:: all(dim, keepdim=False, out=None) -> Tensor
Returns True if all elements in each row of the tensor in the given
dimension :attr:`dim` are non-zero, False otherwise.
dimension :attr:`dim` are True, False otherwise.
If :attr:`keepdim` is ``True``, the output tensor is of the same size as
:attr:`input` except in the dimension :attr:`dim` where it is of size 1.
@ -313,15 +313,16 @@ Args:
Example::
>>> a = torch.randn(4, 2).byte() % 2
>>> a = torch.rand(4, 2).bool()
>>> a
tensor([[0, 0],
[0, 0],
[0, 1],
[1, 1]], dtype=torch.uint8)
tensor([[True, True],
[True, False],
[True, True],
[True, True]], dtype=torch.bool)
>>> a.all(dim=1)
tensor([0, 0, 0, 1], dtype=torch.uint8)
tensor([ True, False, True, True], dtype=torch.bool)
>>> a.all(dim=0)
tensor([ True, False], dtype=torch.bool)
""")
add_docstr_all('allclose',
@ -335,20 +336,19 @@ add_docstr_all('any',
r"""
.. function:: any() -> bool
Returns True if any elements in the tensor are non-zero, False otherwise.
Returns True if any elements in the tensor are True, False otherwise.
Example::
>>> a = torch.randn(1, 3).byte() % 2
>>> a = torch.rand(1, 2).bool()
>>> a
tensor([[0, 0, 1]], dtype=torch.uint8)
tensor([[False, True]], dtype=torch.bool)
>>> a.any()
tensor(1, dtype=torch.uint8)
tensor(True, dtype=torch.bool)
.. function:: any(dim, keepdim=False, out=None) -> Tensor
Returns True if any elements in each row of the tensor in the given
dimension :attr:`dim` are non-zero, False otherwise.
dimension :attr:`dim` are True, False otherwise.
If :attr:`keepdim` is ``True``, the output tensor is of the same size as
:attr:`input` except in the dimension :attr:`dim` where it is of size 1.
@ -362,15 +362,16 @@ Args:
Example::
>>> a = torch.randn(4, 2).byte() % 2
>>> a = torch.randn(4, 2) < 0
>>> a
tensor([[1, 0],
[0, 0],
[0, 1],
[0, 0]], dtype=torch.uint8)
>>> a.any(dim=1)
tensor([1, 0, 1, 0], dtype=torch.uint8)
tensor([[ True, True],
[False, True],
[ True, True],
[False, False]])
>>> a.any(1)
tensor([ True, True, True, False])
>>> a.any(0)
tensor([True, True])
""")
add_docstr_all('apply_',
@ -1510,13 +1511,13 @@ add_docstr_all('masked_scatter_',
masked_scatter_(mask, source)
Copies elements from :attr:`source` into :attr:`self` tensor at positions where
the :attr:`mask` is one.
the :attr:`mask` is True.
The shape of :attr:`mask` must be :ref:`broadcastable <broadcasting-semantics>`
with the shape of the underlying tensor. The :attr:`source` should have at least
as many elements as the number of ones in :attr:`mask`
Args:
mask (ByteTensor): the binary mask
mask (BoolTensor): the boolean mask
source (Tensor): the tensor to copy from
.. note::
@ -1530,12 +1531,12 @@ add_docstr_all('masked_fill_',
masked_fill_(mask, value)
Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is
one. The shape of :attr:`mask` must be
True. The shape of :attr:`mask` must be
:ref:`broadcastable <broadcasting-semantics>` with the shape of the underlying
tensor.
Args:
mask (ByteTensor): the binary mask
mask (BoolTensor): the boolean mask
value (float): the value to fill in with
""")

View File

@ -1694,16 +1694,15 @@ The second argument can be a number or a tensor whose shape is
Args:
input (Tensor): the tensor to compare
other (Tensor or float): the tensor or value to compare
out (Tensor, optional): the output tensor. Must be a `ByteTensor`
out (Tensor, optional): the output tensor. Must be a `BoolTensor`
Returns:
Tensor: A ``torch.ByteTensor`` containing a 1 at each location where comparison is true
Tensor: A ``torch.BoolTensor`` containing a True at each location where comparison is true
Example::
>>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 1, 0],
[ 0, 1]], dtype=torch.uint8)
tensor([[True, False], [False, True]])
""")
add_docstr(torch.equal,
@ -2002,16 +2001,15 @@ The second argument can be a number or a tensor whose shape is
Args:
input (Tensor): the tensor to compare
other (Tensor or float): the tensor or value to compare
out (Tensor, optional): the output tensor that must be a `ByteTensor`
out (Tensor, optional): the output tensor that must be a `BoolTensor`
Returns:
Tensor: A ``torch.ByteTensor`` containing a 1 at each location where comparison is true
Tensor: A ``torch.BoolTensor`` containing a True at each location where comparison is true
Example::
>>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 1, 1],
[ 0, 1]], dtype=torch.uint8)
tensor([[True, True], [False, True]])
""")
add_docstr(torch.geqrf,
@ -2165,16 +2163,15 @@ The second argument can be a number or a tensor whose shape is
Args:
input (Tensor): the tensor to compare
other (Tensor or float): the tensor or value to compare
out (Tensor, optional): the output tensor that must be a `ByteTensor`
out (Tensor, optional): the output tensor that must be a `BoolTensor`
Returns:
Tensor: A ``torch.ByteTensor`` containing a 1 at each location where comparison is true
Tensor: A ``torch.BoolTensor`` containing a True at each location where comparison is true
Example::
>>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 0, 1],
[ 0, 0]], dtype=torch.uint8)
tensor([[False, True], [False, False]])
""")
add_docstr(torch.histc,
@ -2288,12 +2285,12 @@ Arguments:
input (Tensor): A tensor to check
Returns:
Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements.
Tensor: A ``torch.BoolTensor`` containing a True at each location of `NaN` elements.
Example::
>>> torch.isnan(torch.tensor([1, float('nan'), 2]))
tensor([ 0, 1, 0], dtype=torch.uint8)
tensor([False, True, False])
""")
add_docstr(torch.is_floating_point,
@ -2359,16 +2356,15 @@ The second argument can be a number or a tensor whose shape is
Args:
input (Tensor): the tensor to compare
other (Tensor or float): the tensor or value to compare
out (Tensor, optional): the output tensor that must be a `ByteTensor`
out (Tensor, optional): the output tensor that must be a `BoolTensor`
Returns:
Tensor: A ``torch.ByteTensor`` containing a 1 at each location where comparison is true
Tensor: A ``torch.BoolTensor`` containing a True at each location where comparison is true
Example::
>>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 1, 0],
[ 1, 1]], dtype=torch.uint8)
tensor([[True, False], [True, True]])
""")
add_docstr(torch.lerp,
@ -2682,16 +2678,15 @@ The second argument can be a number or a tensor whose shape is
Args:
input (Tensor): the tensor to compare
other (Tensor or float): the tensor or value to compare
out (Tensor, optional): the output tensor that must be a `ByteTensor`
out (Tensor, optional): the output tensor that must be a `BoolTensor`
Returns:
Tensor: A `torch.ByteTensor` containing a 1 at each location where comparison is true
Tensor: A `torch.BoolTensor` containing a True at each location where comparison is true
Example::
>>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 0, 0],
[ 1, 0]], dtype=torch.uint8)
tensor([[False, False], [True, False]])
""")
add_docstr(torch.lu_solve,
@ -2723,7 +2718,7 @@ add_docstr(torch.masked_select,
masked_select(input, mask, out=None) -> Tensor
Returns a new 1-D tensor which indexes the :attr:`input` tensor according to
the binary mask :attr:`mask` which is a `ByteTensor`.
the boolean mask :attr:`mask` which is a `BoolTensor`.
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor don't need
to match, but they must be :ref:`broadcastable <broadcasting-semantics>`.
@ -2733,7 +2728,7 @@ to match, but they must be :ref:`broadcastable <broadcasting-semantics>`.
Args:
input (Tensor): the input data
mask (ByteTensor): the tensor containing the binary mask to index with
mask (BoolTensor): the tensor containing the boolean mask to index with
out (Tensor, optional): the output tensor
Example::
@ -2745,9 +2740,9 @@ Example::
[ 0.1307, -2.0608, 0.1244, 2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[ 0, 0, 0, 0],
[ 0, 1, 1, 1],
[ 0, 0, 0, 1]], dtype=torch.uint8)
tensor([[False, False, False, False],
[False, True, True, True],
[False, False, False, True]])
>>> torch.masked_select(x, mask)
tensor([ 1.2252, 0.5002, 0.6248, 2.0139])
""")
@ -3495,16 +3490,15 @@ The second argument can be a number or a tensor whose shape is
Args:
input (Tensor): the tensor to compare
other (Tensor or float): the tensor or value to compare
out (Tensor, optional): the output tensor that must be a `ByteTensor`
out (Tensor, optional): the output tensor that must be a `BoolTensor`
Returns:
Tensor: A ``torch.ByteTensor`` containing a 1 at each location where comparison is true.
Tensor: A ``torch.BoolTensor`` containing a True at each location where comparison is true.
Example::
>>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 0, 1],
[ 1, 0]], dtype=torch.uint8)
tensor([[False, True], [True, False]])
""")
add_docstr(torch.neg,
@ -5987,9 +5981,9 @@ The operation is defined as:
The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable <broadcasting-semantics>`.
Arguments:
condition (ByteTensor): When True (nonzero), yield input, otherwise yield other
input (Tensor): values selected at indices where :attr:`condition` is ``True``
other (Tensor): values selected at indices where :attr:`condition` is ``False``
condition (BoolTensor): When True (nonzero), yield x, otherwise yield y
x (Tensor): values selected at indices where :attr:`condition` is ``True``
y (Tensor): values selected at indices where :attr:`condition` is ``False``
Returns:
Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other`