Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Updated docs and added deprecation warnings to acknowledge bool tensors (#22261)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22261
ghimport-source-id: 1611d62d056a04c0ad15ef662e594a3d206a78e2

Test Plan: Imported from OSS

Differential Revision: D16005990
Pulled By: izdeby

fbshipit-source-id: 2413824aa75a0755719e4df11acd21e6607e5a85
This commit is contained in:
parent 520982d1df
commit 19c675178f
@@ -15,6 +15,10 @@ static std::vector<Tensor> expandTensors(const Tensor & self, TensorList indices
   std::vector<Tensor> result;
   for (const auto & index : indices) {
     if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
+      if (index.scalar_type() == kByte) {
+        AT_WARN("indexing with dtype torch.uint8 is now deprecated," \
+        " please use a dtype torch.bool instead.");
+      }
       // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
       // corresponding dimensions in self
       for (int64_t j = 0; j < index.dim(); j++) {
@@ -122,4 +126,3 @@ struct AdvancedIndex {
 
-
 }}
 
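For context, the user-visible behavior this hunk introduces, as a minimal Python sketch (assuming PyTorch >= 1.2, where bool indexing is supported): a uint8 index still works but now warns, while a bool index stays silent.

    import warnings

    import torch

    v = torch.randn(3)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")  # defeat once-per-location deduplication
        _ = v[torch.tensor([1, 0, 1], dtype=torch.uint8)]  # deprecated path: warns
        _ = v[torch.tensor([True, False, True])]           # preferred path: silent
    assert len(w) == 1  # only the uint8 index triggered a UserWarning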
@@ -10,6 +10,8 @@ Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, Scalar value) {
   // As we dispatch on self and TH is type-checked, we need different definitions.
   // This can be fixed by moving to ATen.
   if (mask.dtype() == at::ScalarType::Byte) {
+    AT_WARN("masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
+    "please use a mask with dtype torch.bool instead.");
     return legacy::cpu::_th_masked_fill_(self, mask, value);
   } else {
     return legacy::cpu::_th_masked_fill_bool_(self, mask, value);
@@ -20,6 +22,8 @@ Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, const Tensor & valu
   // As we dispatch on self and TH is type-checked, we need different definitions.
   // This can be fixed by moving to ATen.
   if (mask.dtype() == at::ScalarType::Byte) {
+    AT_WARN("masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
+    "please use a mask with dtype torch.bool instead.");
     return legacy::cpu::_th_masked_fill_(self, mask, value);
   } else {
     return legacy::cpu::_th_masked_fill_bool_(self, mask, value);
@@ -30,6 +34,8 @@ Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & s
   // As we dispatch on self and TH is type-checked, we need different definitions.
   // This can be fixed by moving to ATen.
   if (mask.dtype() == at::ScalarType::Byte) {
+    AT_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
+    "please use a mask with dtype torch.bool instead.");
     return legacy::cpu::_th_masked_scatter_(self, mask, source);
   } else {
     return legacy::cpu::_th_masked_scatter_bool_(self, mask, source);
@@ -38,6 +44,8 @@ Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & s
 
 Tensor masked_select_cpu(const Tensor & self, const Tensor & mask) {
   if (mask.dtype() == at::ScalarType::Byte) {
+    AT_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \
+    "please use a mask with dtype torch.bool instead.");
     return legacy::cpu::_th_masked_select(self, mask);
   } else {
     return legacy::cpu::_th_masked_select_bool(self, mask);
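A quick sketch of the call pattern these definitions steer users toward (values arbitrary): a torch.bool mask takes the `_bool_` branch silently, while a uint8 mask takes the legacy branch and warns.

    import torch

    x = torch.randn(4)
    x.masked_fill_(x > 0, 0.0)            # comparison yields a torch.bool mask: silent
    x.masked_fill_((x < 0).byte(), 0.0)   # uint8 mask: still works, but warns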
@@ -10,6 +10,8 @@ Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, Scalar value) {
   // As we dispatch on self and TH is type-checked, we need different definitions.
   // This can be fixed by moving to ATen.
   if (mask.dtype() == at::ScalarType::Byte) {
+    AT_WARN("masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
+    "please use a mask with dtype torch.bool instead.");
     return legacy::cuda::_th_masked_fill_(self, mask, value);
   } else {
     return legacy::cuda::_th_masked_fill_bool_(self, mask, value);
@@ -20,6 +22,8 @@ Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Tensor & val
   // As we dispatch on self and TH is type-checked, we need different definitions.
   // This can be fixed by moving to ATen.
   if (mask.dtype() == at::ScalarType::Byte) {
+    AT_WARN("masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
+    "please use a mask with dtype torch.bool instead.");
     return legacy::cuda::_th_masked_fill_(self, mask, value);
   } else {
     return legacy::cuda::_th_masked_fill_bool_(self, mask, value);
@@ -30,6 +34,8 @@ Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor &
   // As we dispatch on self and TH is type-checked, we need different definitions.
   // This can be fixed by moving to ATen.
   if (mask.dtype() == at::ScalarType::Byte) {
+    AT_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
+    "please use a mask with dtype torch.bool instead.");
     return legacy::cuda::_th_masked_scatter_(self, mask, source);
   } else {
     return legacy::cuda::_th_masked_scatter_bool_(self, mask, source);
@@ -38,6 +44,8 @@ Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor &
 
 Tensor masked_select_cuda(const Tensor & self, const Tensor & mask) {
   if (mask.dtype() == at::ScalarType::Byte) {
+    AT_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \
+    "please use a mask with dtype torch.bool instead.");
     return legacy::cuda::_th_masked_select(self, mask);
   } else {
     return legacy::cuda::_th_masked_select_bool(self, mask);
@@ -469,9 +469,9 @@ view of a storage and defines numeric operations on it.
 .. automethod:: where
 .. automethod:: zero_
 
-.. class:: ByteTensor()
+.. class:: BoolTensor()
 
-The following methods are unique to :class:`torch.ByteTensor`.
+The following methods are unique to :class:`torch.BoolTensor`.
 
 .. automethod:: all
 .. automethod:: any
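A small sketch of the behavior now documented on torch.BoolTensor:

    import torch

    m = torch.tensor([True, False, True])
    print(m.all())  # tensor(False)
    print(m.any())  # tensor(True)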
@@ -2,7 +2,7 @@ from common_utils import TestCase, run_tests
 import torch
 from torch import tensor
 import unittest
-
+import warnings
 
 class TestIndexing(TestCase):
     def test_single_int(self):
@@ -44,9 +44,11 @@ class TestIndexing(TestCase):
         v = torch.tensor([True, False, True], dtype=torch.bool)
         boolIndices = torch.tensor([True, False, False], dtype=torch.bool)
         uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8)
-        self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
-        self.assertEqual(v[boolIndices], v[uint8Indices])
-        self.assertEqual(v[boolIndices], tensor([True], dtype=torch.bool))
+        with warnings.catch_warnings(record=True) as w:
+            self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
+            self.assertEqual(v[boolIndices], v[uint8Indices])
+            self.assertEqual(v[boolIndices], tensor([True], dtype=torch.bool))
+            self.assertEquals(len(w), 2)
 
     def test_bool_indices_accumulate(self):
         mask = torch.zeros(size=(10, ), dtype=torch.bool)
@@ -64,8 +66,10 @@ class TestIndexing(TestCase):
     def test_byte_mask(self):
         v = torch.randn(5, 7, 3)
         mask = torch.ByteTensor([1, 0, 1, 1, 0])
-        self.assertEqual(v[mask].shape, (3, 7, 3))
-        self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
+        with warnings.catch_warnings(record=True) as w:
+            self.assertEqual(v[mask].shape, (3, 7, 3))
+            self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
+            self.assertEquals(len(w), 2)
 
         v = torch.tensor([1.])
         self.assertEqual(v[v == 0], torch.tensor([]))
@@ -73,15 +77,19 @@ class TestIndexing(TestCase):
     def test_byte_mask_accumulate(self):
         mask = torch.zeros(size=(10, ), dtype=torch.uint8)
         y = torch.ones(size=(10, 10))
-        y.index_put_((mask, ), y[mask], accumulate=True)
-        self.assertEqual(y, torch.ones(size=(10, 10)))
+        with warnings.catch_warnings(record=True) as w:
+            y.index_put_((mask, ), y[mask], accumulate=True)
+            self.assertEqual(y, torch.ones(size=(10, 10)))
+            self.assertEquals(len(w), 2)
 
     def test_multiple_byte_mask(self):
         v = torch.randn(5, 7, 3)
         # note: these broadcast together and are transposed to the first dim
         mask1 = torch.ByteTensor([1, 0, 1, 1, 0])
         mask2 = torch.ByteTensor([1, 1, 1])
-        self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
+        with warnings.catch_warnings(record=True) as w:
+            self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
+            self.assertEquals(len(w), 2)
 
     def test_byte_mask2d(self):
         v = torch.randn(5, 7, 3)
@@ -121,7 +129,7 @@ class TestIndexing(TestCase):
         y[idx] = -1
         self.assertEqual(x, y)
 
-        mask = torch.zeros(4, 3).byte()
+        mask = torch.zeros(4, 3).bool()
         y[mask] = -1
         self.assertEqual(x, y)
 
@@ -299,7 +307,11 @@ class TestIndexing(TestCase):
         x = torch.arange(0., 16).view(4, 4)
         b = torch.ByteTensor([True, False, True, False])
         value = torch.tensor([3., 4., 5., 6.])
-        x[b] = value
+
+        with warnings.catch_warnings(record=True) as w:
+            x[b] = value
+            self.assertEquals(len(w), 1)
+
         self.assertEqual(x[0], value)
         self.assertEqual(x[1], torch.arange(4, 8))
         self.assertEqual(x[2], value)
@@ -378,7 +390,6 @@ class TestIndexing(TestCase):
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-
 class NumpyTests(TestCase):
     def test_index_no_floats(self):
         a = torch.tensor([[[5.]]])
@@ -484,10 +495,11 @@ class NumpyTests(TestCase):
         index = tensor([False] * 6)
         self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])
 
-        index = torch.ByteTensor(4, 4).zero_()
-        self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])
-
-        self.assertRaisesRegex(IndexError, 'mask', lambda: arr[(slice(None), index)])
+        with warnings.catch_warnings(record=True) as w:
+            index = torch.ByteTensor(4, 4).zero_()
+            self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])
+            self.assertRaisesRegex(IndexError, 'mask', lambda: arr[(slice(None), index)])
+            self.assertEquals(len(w), 2)
 
     def test_boolean_indexing_onedim(self):
         # Indexing a 2-dimensional array with
@@ -606,6 +618,5 @@ class NumpyTests(TestCase):
         expected = b.double().unsqueeze(1).expand(100, 100)
         self.assertEqual(a, expected)
 
-
 if __name__ == '__main__':
     run_tests()
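All of the tests above rely on the same recipe: record warnings, run the deprecated uint8 operation, and assert on the count. A standalone sketch of that recipe outside the test harness (the simplefilter call is added here so Python's once-per-location deduplication doesn't hide repeats):

    import warnings

    import torch

    v = torch.randn(5, 7, 3)
    mask = torch.ByteTensor([1, 0, 1, 1, 0])
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        _ = v[mask].shape   # first deprecated indexing op
        _ = v[mask]         # second deprecated indexing op
    assert len(w) == 2      # one UserWarning per op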
@@ -8285,28 +8285,30 @@ class _TestTorchMixin(object):
                                  [True, False, True, False, True]], device=device))
 
     def test_masked_scatter(self):
-        for dtype in [torch.uint8, torch.bool]:
-            num_copy, num_dest = 3, 10
-            dest = torch.randn(num_dest)
-            src = torch.randn(num_copy)
-            mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=dtype)
-            dest2 = dest.clone()
-            dest.masked_scatter_(mask, src)
-            j = 0
-            for i in range(num_dest):
-                if mask[i]:
-                    dest2[i] = src[j]
-                    j += 1
-            self.assertEqual(dest, dest2, 0)
-
-            # make source bigger than number of 1s in mask
-            src = torch.randn(num_dest)
-            dest.masked_scatter_(mask, src)
-
-            # make src smaller. this should fail
-            src = torch.randn(num_copy - 1)
-            with self.assertRaises(RuntimeError):
-                dest.masked_scatter_(mask, src)
+        with warnings.catch_warnings(record=True) as w:
+            for dtype in [torch.uint8, torch.bool]:
+                num_copy, num_dest = 3, 10
+                dest = torch.randn(num_dest)
+                src = torch.randn(num_copy)
+                mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=dtype)
+                dest2 = dest.clone()
+                dest.masked_scatter_(mask, src)
+                j = 0
+                for i in range(num_dest):
+                    if mask[i]:
+                        dest2[i] = src[j]
+                        j += 1
+                self.assertEqual(dest, dest2, 0)
+
+                # make source bigger than number of 1s in mask
+                src = torch.randn(num_dest)
+                dest.masked_scatter_(mask, src)
+
+                # make src smaller. this should fail
+                src = torch.randn(num_copy - 1)
+                with self.assertRaises(RuntimeError):
+                    dest.masked_scatter_(mask, src)
+            self.assertEqual(len(w), 3)
 
     def test_masked_scatter_bool_tensor(self):
         for device in torch.testing.get_all_device_types():
@@ -8323,40 +8325,44 @@ class _TestTorchMixin(object):
 
     def test_masked_select(self):
         for device in torch.testing.get_all_device_types():
-            for dtype in [torch.uint8, torch.bool]:
-                num_src = 10
-                src = torch.randn(num_src, device=device)
-                mask = torch.rand(num_src, device=device).clamp(0, 1).mul(2).floor().to(dtype)
-                dst = src.masked_select(mask)
-                dst2 = []
-                for i in range(num_src):
-                    if mask[i]:
-                        dst2 += [src[i]]
-                self.assertEqual(dst, torch.tensor(dst2), 0)
-
-                dst3 = torch.empty_like(src, device=device)
-                torch.masked_select(src, mask, out=dst3)
-                self.assertEqual(dst3, torch.Tensor(dst2), 0)
+            with warnings.catch_warnings(record=True) as w:
+                for dtype in [torch.uint8, torch.bool]:
+                    num_src = 10
+                    src = torch.randn(num_src, device=device)
+                    mask = torch.rand(num_src, device=device).clamp(0, 1).mul(2).floor().to(dtype)
+                    dst = src.masked_select(mask)
+                    dst2 = []
+                    for i in range(num_src):
+                        if mask[i]:
+                            dst2 += [src[i]]
+                    self.assertEqual(dst, torch.tensor(dst2), 0)
+
+                    dst3 = torch.empty_like(src, device=device)
+                    torch.masked_select(src, mask, out=dst3)
+                    self.assertEqual(dst3, torch.Tensor(dst2), 0)
+                self.assertEquals(len(w), 1)
 
     def test_masked_fill(self):
-        for dtype in [torch.uint8, torch.bool]:
-            num_dest = 10
-            dst = torch.randn(num_dest)
-            mask = torch.rand(num_dest).mul(2).floor().to(dtype)
-            val = random.random()
-            dst2 = dst.clone()
-            dst.masked_fill_(mask, val)
-            for i in range(num_dest):
-                if mask[i]:
-                    dst2[i] = val
-            self.assertEqual(dst, dst2, 0)
-
-            # test non-contiguous case
-            dst = torch.randn(num_dest, num_dest, num_dest).permute((2, 0, 1))
-            dst2 = dst.clone()
-            dst.masked_fill_((dst > 0).to(dtype), val)
-            dst2.masked_fill_((dst2 > 0).to(dtype), val)
-            self.assertEqual(dst, dst2, 0)
+        with warnings.catch_warnings(record=True) as w:
+            for dtype in [torch.uint8, torch.bool]:
+                num_dest = 10
+                dst = torch.randn(num_dest)
+                mask = torch.rand(num_dest).mul(2).floor().to(dtype)
+                val = random.random()
+                dst2 = dst.clone()
+                dst.masked_fill_(mask, val)
+                for i in range(num_dest):
+                    if mask[i]:
+                        dst2[i] = val
+                self.assertEqual(dst, dst2, 0)
+
+                # test non-contiguous case
+                dst = torch.randn(num_dest, num_dest, num_dest).permute((2, 0, 1))
+                dst2 = dst.clone()
+                dst.masked_fill_((dst > 0).to(dtype), val)
+                dst2.masked_fill_((dst2 > 0).to(dtype), val)
+                self.assertEqual(dst, dst2, 0)
+            self.assertEquals(len(w), 3)
 
     def test_masked_fill_bool_tensor(self):
         for device in torch.testing.get_all_device_types():
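test_masked_scatter checks the kernel against a plain-Python reference loop; extracted as a standalone sketch, that loop is a concise definition of what masked_scatter_ computes (source values are consumed left to right and placed where the mask is true):

    import torch

    dest = torch.zeros(6)
    mask = torch.tensor([False, True, False, True, True, False])
    src = torch.tensor([1., 2., 3.])

    expected = dest.clone()
    j = 0
    for i in range(len(dest)):   # the reference loop from the test
        if mask[i]:
            expected[i] = src[j]
            j += 1

    dest.masked_scatter_(mask, src)
    assert torch.equal(dest, expected)   # tensor([0., 1., 0., 2., 3., 0.])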
@@ -39,7 +39,7 @@ read gen_pyi for the gory details.
 
 needed_modules = set()
 
-FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: bool=False"
+FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False"
 
 # this could be more precise w.r.t list contents etc. How to do Ellipsis?
 INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"
@@ -120,7 +120,7 @@ def type_to_python(typename, size=None):
         'IntArrayRef[]': 'Union[_int, _size]',
         'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]',
         'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]',
-        'bool': 'bool',
+        'bool': '_bool',
         'double': '_float',
         'int64_t': '_int',
         'accreal': 'Number',
@@ -208,7 +208,7 @@ def sig_for_ops(opname):
             tname = 'bool'
         else:
             tname = 'int'
-        if tname in {'float', 'int'}:
+        if tname in {'float', 'int', 'bool'}:
             tname = 'builtins.' + tname
         return ['def {}(self) -> {}: ...'.format(opname, tname)]
     else:
@@ -275,7 +275,7 @@ def generate_type_hints(fname, decls, is_tensor=False):
         python_args += ["dtype: _dtype=None",
                         "layout: layout=strided",
                         "device: Union[_device, str, None]=None",
-                        "requires_grad:bool=False"]
+                        "requires_grad:_bool=False"]
 
     python_args_s = ', '.join(python_args)
     python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']]
@@ -317,9 +317,9 @@ def generate_type_hints(fname, decls, is_tensor=False):
 def gen_nn_modules(out):
     def replace_forward(m):
         # We instruct mypy to not emit errors for the `forward` and `__call__` declarations since mypy
         # would otherwise correctly point out that Module's descendants' `forward` declarations
         # conflict with `Module`s. Specificlaly, `Module` defines `forward(self, *args)` while the
         # descandantes define more specific forms, such as `forward(self, input: Tensor)`, which
         # violates Liskov substitutability. The 'mypy' team recommended this solution for now.
         forward_def = m.group(0) + "  # type: ignore"
         call_def = re.sub(r'def forward', 'def __call__', forward_def)
@@ -412,7 +412,7 @@ def gen_pyi(declarations_path, out):
 
     unsorted_function_hints = collections.defaultdict(list)
     unsorted_function_hints.update({
-        'set_flush_denormal': ['def set_flush_denormal(mode: bool) -> bool: ...'],
+        'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
         'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
         'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
         'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
@@ -428,7 +428,7 @@ def gen_pyi(declarations_path, out):
         'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
         'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
                               ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
-                              ' device: Union[_device, str, None]=None, requires_grad:bool=False) -> Tensor: ...'],
+                              ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
         'range': ['def range(start: Number, end: Number,'
                   ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS)],
@@ -503,28 +503,28 @@ def gen_pyi(declarations_path, out):
         '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
                         " -> None: ...".format(INDICES)],
         'tolist': ['def tolist(self) -> List: ...'],
-        'requires_grad_': ['def requires_grad_(self, mode: bool=True) -> Tensor: ...'],
+        'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
         'element_size': ['def element_size(self) -> _int: ...'],
         'dim': ['def dim(self) -> _int: ...'],
         'ndimension': ['def ndimension(self) -> _int: ...'],
         'nelement': ['def nelement(self) -> _int: ...'],
-        'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: bool=False) -> Tensor: ...'],
+        'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: _bool=False) -> Tensor: ...'],
         'numpy': ['def numpy(self) -> Any: ...'],
         'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
         'map_': ['def map_(tensor: Tensor, callable: Callable) -> Tensor: ...'],
         'storage': ['def storage(self) -> Storage: ...'],
-        'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: bool=False)'
+        'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: _bool=False)'
                  ' -> Union[str, Tensor]: ...'],
         'get_device': ['def get_device(self) -> _int: ...'],
         'contiguous': ['def contiguous(self) -> Tensor: ...'],
-        'is_contiguous': ['def is_contiguous(self) -> bool: ...'],
-        'is_cuda': ['is_cuda: bool'],
-        'is_leaf': ['is_leaf: bool'],
+        'is_contiguous': ['def is_contiguous(self) -> _bool: ...'],
+        'is_cuda': ['is_cuda: _bool'],
+        'is_leaf': ['is_leaf: _bool'],
         'storage_offset': ['def storage_offset(self) -> _int: ...'],
-        'to': ['def to(self, dtype: _dtype, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
+        'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
-              'non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
-              'def to(self, other: Tensor, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
+              'non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
+              'def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               ],
         'item': ["def item(self) -> Number: ..."],
     })
@@ -541,7 +541,8 @@ def gen_pyi(declarations_path, out):
                 'def {}(self, value: Number,'
                 ' other: Union[Tensor, Number]{})'
                 ' -> Tensor: ...'.format(name, out_suffix))
-    simple_conversions = ['byte', 'char', 'cpu', 'double', 'float', 'half', 'int', 'long', 'short']
+    simple_conversions = ['byte', 'char', 'cpu', 'double', 'float',
+                          'half', 'int', 'long', 'short', 'bool']
    for name in simple_conversions:
         unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))
 
@@ -568,11 +569,11 @@ def gen_pyi(declarations_path, out):
     # TODO: These are deprecated, maybe we shouldn't type hint them
     legacy_class_hints = []
     for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
-              'ShortStorage', 'CharStorage', 'ByteStorage'):
+              'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage'):
         legacy_class_hints.append('class {}(Storage): ...'.format(c))
 
     for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
-              'ShortTensor', 'CharTensor', 'ByteTensor'):
+              'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
         legacy_class_hints.append('class {}(Tensor): ...'.format(c))
 
     # Generate type signatures for dtype classes
@@ -584,7 +585,7 @@ def gen_pyi(declarations_path, out):
                          for n in
                          ['float32', 'float', 'float64', 'double', 'float16', 'half',
                           'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
-                          'complex32', 'complex64', 'complex128', 'quint8', 'qint8', 'qint32']]
+                          'complex32', 'complex64', 'complex128', 'quint8', 'qint8', 'qint32', 'bool']]
 
     # Write out the stub
     # ~~~~~~~~~~~~~~~~~~
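The reason the generated stub writes `_bool` rather than `bool`: at module scope in `torch/__init__.pyi`, the name `bool` is taken by the `torch.bool` dtype (added to the dtype list in this very diff), so annotations must reach `builtins.bool` through an alias. A sketch of the name collision (the asserts are illustrative):

    import builtins

    import torch

    assert isinstance(torch.bool, torch.dtype)   # torch.bool is a dtype object...
    assert torch.bool is not builtins.bool       # ...not the builtin type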
@@ -24,9 +24,9 @@ __all__ = [
     'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
     'no_grad', 'enable_grad', 'rand', 'randn',
     'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
-    'ShortStorage', 'CharStorage', 'ByteStorage',
+    'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage',
     'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
-    'ShortTensor', 'CharTensor', 'ByteTensor', 'Tensor',
+    'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
 ]
 
 ################################################################################
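With the new exports, the legacy bool classes behave like their siblings; a one-line usage sketch:

    import torch

    m = torch.BoolTensor([True, False, True])
    assert m.dtype == torch.bool and isinstance(m, torch.Tensor)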
@@ -44,6 +44,7 @@ per_tensor_affine: qscheme = ...
 # is necessary
 _int = builtins.int
 _float = builtins.float
+_bool = builtins.bool
 
 class device:
     type: str
@@ -67,7 +68,7 @@ _qscheme = qscheme
 _size = Union[Size, List[_int], Tuple[_int, ...]]
 
 # Meta-type for "numeric" things; matches our docs
-Number = Union[builtins.int, builtins.float]
+Number = Union[builtins.int, builtins.float, builtins.bool]
 
 class Generator:
     device: _device = ...
@@ -85,7 +86,7 @@ class Tensor:
     dtype: _dtype = ...
     shape: Size = ...
     device: _device = ...
-    requires_grad: bool = ...
+    requires_grad: _bool = ...
     grad: Optional[Tensor] = ...
 
     ${tensor_method_hints}
@@ -93,7 +94,7 @@ class Tensor:
     # Manually defined methods from torch/tensor.py
     def register_hook(self, hook: Callable) -> Any: ...
     def retain_grad(self) -> None: ...
-    def is_shared(self) -> bool: ...
+    def is_shared(self) -> _bool: ...
     def share_memory_(self) -> None: ...
     # TODO: fill in the types for these, or otherwise figure out some
     # way to not have to write these out again...
@@ -115,15 +116,14 @@ ${dtype_class_hints}
 # Pure Python functions defined in torch/__init__.py
 
 def typename(obj) -> str: ...
-def is_tensor(obj) -> bool: ...
-def is_storage(obj) -> bool: ...
+def is_tensor(obj) -> _bool: ...
+def is_storage(obj) -> _bool: ...
 def set_default_tensor_type(type) -> None: ...  # ick, what a bad legacy API
 def set_default_dtype(d : _dtype) -> None: ...
 def manager_path() -> str: ...
-def compiled_with_cxx11_abi() -> bool: ...
+def compiled_with_cxx11_abi() -> _bool: ...
 
 # The return value of this function depends on the value of `as_tuple`,
 # (similar to `unique`, `lu`, etc.); as such, it is not
 # possible to type correctly
-def nonzero(input: Tensor, *, out: Optional[Tensor]=None, as_tuple: Optional[bool]=None): ...
-
+def nonzero(input: Tensor, *, out: Optional[Tensor]=None, as_tuple: Optional[_bool]=None): ...
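Widening `Number` to include `builtins.bool` matches Python's numeric tower, where bool is a subclass of int, so scalar positions typed as Number now accept True/False. A sketch:

    import torch

    assert isinstance(True, int)   # bool already sits inside the int/float tower
    t = torch.tensor(True)         # a Python bool in a scalar position
    assert t.dtype == torch.bool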
@@ -286,20 +286,20 @@ add_docstr_all('all',
 r"""
 .. function:: all() -> bool
 
-Returns True if all elements in the tensor are non-zero, False otherwise.
+Returns True if all elements in the tensor are True, False otherwise.
 
 Example::
 
-    >>> a = torch.randn(1, 3).byte() % 2
+    >>> a = torch.rand(1, 2).bool()
     >>> a
-    tensor([[1, 0, 0]], dtype=torch.uint8)
+    tensor([[False, True]], dtype=torch.bool)
     >>> a.all()
-    tensor(0, dtype=torch.uint8)
+    tensor(False, dtype=torch.bool)
 
 .. function:: all(dim, keepdim=False, out=None) -> Tensor
 
 Returns True if all elements in each row of the tensor in the given
-dimension :attr:`dim` are non-zero, False otherwise.
+dimension :attr:`dim` are True, False otherwise.
 
 If :attr:`keepdim` is ``True``, the output tensor is of the same size as
 :attr:`input` except in the dimension :attr:`dim` where it is of size 1.
@@ -313,15 +313,16 @@ Args:
 
 Example::
 
-    >>> a = torch.randn(4, 2).byte() % 2
+    >>> a = torch.rand(4, 2).bool()
     >>> a
-    tensor([[0, 0],
-            [0, 0],
-            [0, 1],
-            [1, 1]], dtype=torch.uint8)
+    tensor([[True, True],
+            [True, False],
+            [True, True],
+            [True, True]], dtype=torch.bool)
     >>> a.all(dim=1)
-    tensor([0, 0, 0, 1], dtype=torch.uint8)
-
+    tensor([ True, False,  True,  True], dtype=torch.bool)
+    >>> a.all(dim=0)
+    tensor([ True, False], dtype=torch.bool)
 """)
 
 add_docstr_all('allclose',
@@ -335,20 +336,19 @@ add_docstr_all('any',
 r"""
 .. function:: any() -> bool
 
-Returns True if any elements in the tensor are non-zero, False otherwise.
+Returns True if any elements in the tensor are True, False otherwise.
 
 Example::
 
-    >>> a = torch.randn(1, 3).byte() % 2
+    >>> a = torch.rand(1, 2).bool()
     >>> a
-    tensor([[0, 0, 1]], dtype=torch.uint8)
+    tensor([[False, True]], dtype=torch.bool)
     >>> a.any()
-    tensor(1, dtype=torch.uint8)
-
+    tensor(True, dtype=torch.bool)
 .. function:: any(dim, keepdim=False, out=None) -> Tensor
 
 Returns True if any elements in each row of the tensor in the given
-dimension :attr:`dim` are non-zero, False otherwise.
+dimension :attr:`dim` are True, False otherwise.
 
 If :attr:`keepdim` is ``True``, the output tensor is of the same size as
 :attr:`input` except in the dimension :attr:`dim` where it is of size 1.
@@ -362,15 +362,16 @@ Args:
 
 Example::
 
-    >>> a = torch.randn(4, 2).byte() % 2
+    >>> a = torch.randn(4, 2) < 0
     >>> a
-    tensor([[1, 0],
-            [0, 0],
-            [0, 1],
-            [0, 0]], dtype=torch.uint8)
-    >>> a.any(dim=1)
-    tensor([1, 0, 1, 0], dtype=torch.uint8)
-
+    tensor([[ True,  True],
+            [False,  True],
+            [ True,  True],
+            [False, False]])
+    >>> a.any(1)
+    tensor([ True,  True,  True, False])
+    >>> a.any(0)
+    tensor([True, True])
 """)
 
 add_docstr_all('apply_',
@@ -1510,13 +1511,13 @@ add_docstr_all('masked_scatter_',
 masked_scatter_(mask, source)
 
 Copies elements from :attr:`source` into :attr:`self` tensor at positions where
-the :attr:`mask` is one.
+the :attr:`mask` is True.
 The shape of :attr:`mask` must be :ref:`broadcastable <broadcasting-semantics>`
 with the shape of the underlying tensor. The :attr:`source` should have at least
 as many elements as the number of ones in :attr:`mask`
 
 Args:
-    mask (ByteTensor): the binary mask
+    mask (BoolTensor): the boolean mask
     source (Tensor): the tensor to copy from
 
 .. note::
@@ -1530,12 +1531,12 @@ add_docstr_all('masked_fill_',
 masked_fill_(mask, value)
 
 Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is
-one. The shape of :attr:`mask` must be
+True. The shape of :attr:`mask` must be
 :ref:`broadcastable <broadcasting-semantics>` with the shape of the underlying
 tensor.
 
 Args:
-    mask (ByteTensor): the binary mask
+    mask (BoolTensor): the boolean mask
     value (float): the value to fill in with
 """)
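Tying the two rewritten docstrings to code, a short masked_fill_ sketch with a broadcastable BoolTensor mask:

    import torch

    x = torch.zeros(2, 3)
    mask = torch.tensor([True, False, True])   # broadcast across dim 0
    x.masked_fill_(mask, 7.0)
    # tensor([[7., 0., 7.],
    #         [7., 0., 7.]])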
@@ -1694,16 +1694,15 @@ The second argument can be a number or a tensor whose shape is
 Args:
     input (Tensor): the tensor to compare
     other (Tensor or float): the tensor or value to compare
-    out (Tensor, optional): the output tensor. Must be a `ByteTensor`
+    out (Tensor, optional): the output tensor. Must be a `BoolTensor`
 
 Returns:
-    Tensor: A ``torch.ByteTensor`` containing a 1 at each location where comparison is true
+    Tensor: A ``torch.BoolTensor`` containing a True at each location where comparison is true
 
 Example::
 
     >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
-    tensor([[ 1,  0],
-            [ 0,  1]], dtype=torch.uint8)
+    tensor([[True, False], [False, True]])
 """)
 
 add_docstr(torch.equal,
@@ -2002,16 +2001,15 @@ The second argument can be a number or a tensor whose shape is
 Args:
     input (Tensor): the tensor to compare
     other (Tensor or float): the tensor or value to compare
-    out (Tensor, optional): the output tensor that must be a `ByteTensor`
+    out (Tensor, optional): the output tensor that must be a `BoolTensor`
 
 Returns:
-    Tensor: A ``torch.ByteTensor`` containing a 1 at each location where comparison is true
+    Tensor: A ``torch.BoolTensor`` containing a True at each location where comparison is true
 
 Example::
 
     >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
-    tensor([[ 1,  1],
-            [ 0,  1]], dtype=torch.uint8)
+    tensor([[True, True], [False, True]])
 """)
 
 add_docstr(torch.geqrf,
@@ -2165,16 +2163,15 @@ The second argument can be a number or a tensor whose shape is
 Args:
     input (Tensor): the tensor to compare
     other (Tensor or float): the tensor or value to compare
-    out (Tensor, optional): the output tensor that must be a `ByteTensor`
+    out (Tensor, optional): the output tensor that must be a `BoolTensor`
 
 Returns:
-    Tensor: A ``torch.ByteTensor`` containing a 1 at each location where comparison is true
+    Tensor: A ``torch.BoolTensor`` containing a True at each location where comparison is true
 
 Example::
 
     >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
-    tensor([[ 0,  1],
-            [ 0,  0]], dtype=torch.uint8)
+    tensor([[False, True], [False, False]])
 """)
 
 add_docstr(torch.histc,
@@ -2288,12 +2285,12 @@ Arguments:
     input (Tensor): A tensor to check
 
 Returns:
-    Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements.
+    Tensor: A ``torch.BoolTensor`` containing a True at each location of `NaN` elements.
 
 Example::
 
     >>> torch.isnan(torch.tensor([1, float('nan'), 2]))
-    tensor([ 0,  1,  0], dtype=torch.uint8)
+    tensor([False, True, False])
 """)
 
 add_docstr(torch.is_floating_point,
@@ -2359,16 +2356,15 @@ The second argument can be a number or a tensor whose shape is
 Args:
     input (Tensor): the tensor to compare
     other (Tensor or float): the tensor or value to compare
-    out (Tensor, optional): the output tensor that must be a `ByteTensor`
+    out (Tensor, optional): the output tensor that must be a `BoolTensor`
 
 Returns:
-    Tensor: A ``torch.ByteTensor`` containing a 1 at each location where comparison is true
+    Tensor: A ``torch.BoolTensor`` containing a True at each location where comparison is true
 
 Example::
 
     >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
-    tensor([[ 1,  0],
-            [ 1,  1]], dtype=torch.uint8)
+    tensor([[True, False], [True, True]])
 """)
 
 add_docstr(torch.lerp,
@@ -2682,16 +2678,15 @@ The second argument can be a number or a tensor whose shape is
 Args:
     input (Tensor): the tensor to compare
     other (Tensor or float): the tensor or value to compare
-    out (Tensor, optional): the output tensor that must be a `ByteTensor`
+    out (Tensor, optional): the output tensor that must be a `BoolTensor`
 
 Returns:
-    Tensor: A `torch.ByteTensor` containing a 1 at each location where comparison is true
+    Tensor: A `torch.BoolTensor` containing a True at each location where comparison is true
 
 Example::
 
     >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
-    tensor([[ 0,  0],
-            [ 1,  0]], dtype=torch.uint8)
+    tensor([[False, False], [True, False]])
 """)
 
 add_docstr(torch.lu_solve,
@@ -2723,7 +2718,7 @@ add_docstr(torch.masked_select,
 masked_select(input, mask, out=None) -> Tensor
 
 Returns a new 1-D tensor which indexes the :attr:`input` tensor according to
-the binary mask :attr:`mask` which is a `ByteTensor`.
+the boolean mask :attr:`mask` which is a `BoolTensor`.
 
 The shapes of the :attr:`mask` tensor and the :attr:`input` tensor don't need
 to match, but they must be :ref:`broadcastable <broadcasting-semantics>`.
@@ -2733,7 +2728,7 @@ to match, but they must be :ref:`broadcastable <broadcasting-semantics>`.
 
 Args:
     input (Tensor): the input data
-    mask (ByteTensor): the tensor containing the binary mask to index with
+    mask (BoolTensor): the tensor containing the boolean mask to index with
     out (Tensor, optional): the output tensor
 
 Example::
@@ -2745,9 +2740,9 @@ Example::
             [ 0.1307, -2.0608,  0.1244,  2.0139]])
     >>> mask = x.ge(0.5)
     >>> mask
-    tensor([[ 0,  0,  0,  0],
-            [ 0,  1,  1,  1],
-            [ 0,  0,  0,  1]], dtype=torch.uint8)
+    tensor([[False, False, False, False],
+            [False, True, True, True],
+            [False, False, False, True]])
     >>> torch.masked_select(x, mask)
     tensor([ 1.2252,  0.5002,  0.6248,  2.0139])
 """)
@@ -3495,16 +3490,15 @@ The second argument can be a number or a tensor whose shape is
 Args:
     input (Tensor): the tensor to compare
     other (Tensor or float): the tensor or value to compare
-    out (Tensor, optional): the output tensor that must be a `ByteTensor`
+    out (Tensor, optional): the output tensor that must be a `BoolTensor`
 
 Returns:
-    Tensor: A ``torch.ByteTensor`` containing a 1 at each location where comparison is true.
+    Tensor: A ``torch.BoolTensor`` containing a True at each location where comparison is true.
 
 Example::
 
     >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
-    tensor([[ 0,  1],
-            [ 1,  0]], dtype=torch.uint8)
+    tensor([[False, True], [True, False]])
 """)
 
 add_docstr(torch.neg,
@@ -5987,9 +5981,9 @@ The operation is defined as:
 The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable <broadcasting-semantics>`.
 
 Arguments:
-    condition (ByteTensor): When True (nonzero), yield input, otherwise yield other
-    input (Tensor): values selected at indices where :attr:`condition` is ``True``
-    other (Tensor): values selected at indices where :attr:`condition` is ``False``
+    condition (BoolTensor): When True (nonzero), yield x, otherwise yield y
+    x (Tensor): values selected at indices where :attr:`condition` is ``True``
+    y (Tensor): values selected at indices where :attr:`condition` is ``False``
 
 Returns:
     Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other`
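As the rewritten docstrings above show, the comparison family (eq/ne/lt/le/gt/ge) and isnan now document torch.bool results, which feed directly into torch.where and mask indexing. A short sketch tying them together:

    import torch

    a = torch.tensor([1., float('nan'), 3.])
    bad = torch.isnan(a)                               # tensor([False,  True, False])
    cleaned = torch.where(bad, torch.zeros_like(a), a)
    print(cleaned)                                     # tensor([1., 0., 3.])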