From 4fcab92d6c38fdbb725e74de1c4b09cf54fa8132 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 15 Feb 2019 15:54:50 -0800 Subject: [PATCH] Move outplace ops to ATen (#16788) Summary: Based on https://github.com/pytorch/pytorch/pull/12413, with the following additional changes: - Inside `native_functions.yml` move those outplace operators right next to everyone's corresponding inplace operators for convenience of checking if they match when reviewing - `matches_jit_signature: True` for them - Add missing `scatter` with Scalar source - Add missing `masked_fill` and `index_fill` with Tensor source. - Add missing test for `scatter` with Scalar source - Add missing test for `masked_fill` and `index_fill` with Tensor source by checking the gradient w.r.t source - Add missing docs to `tensor.rst` Differential Revision: D14069925 Pulled By: ezyang fbshipit-source-id: bb3f0cb51cf6b756788dc4955667fead6e8796e5 --- aten/src/ATen/core/Tensor.h | 12 +++++- aten/src/ATen/core/TensorMethods.h | 34 ++++++++++++++- aten/src/ATen/core/Type.h | 12 +++++- aten/src/ATen/native/Indexing.cpp | 46 ++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 45 +++++++++++++++++++- docs/source/tensors.rst | 8 ++++ test/common_methods_invocations.py | 15 ++++--- test/test_jit.py | 2 +- test/test_torch.py | 12 ++++-- torch/__init__.pyi.in | 7 ---- torch/_tensor_docs.py | 49 ++++++++++++++++++++++ torch/tensor.py | 35 ---------------- 12 files changed, 218 insertions(+), 59 deletions(-) diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h index 266315586f6..6c337fd9d75 100644 --- a/aten/src/ATen/core/Tensor.h +++ b/aten/src/ATen/core/Tensor.h @@ -392,8 +392,9 @@ class CAFFE2_API Tensor { Tensor irfft(int64_t signal_ndim, bool normalized=false, bool onesided=true, IntArrayRef signal_sizes={}) const; Tensor index(TensorList indices) const; Tensor & index_copy_(int64_t dim, const Tensor & index, const Tensor & source); - Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const; + Tensor index_copy(int64_t dim, const Tensor & index, const Tensor & source) const; Tensor & index_put_(TensorList indices, const Tensor & values, bool accumulate=false); + Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const; Tensor inverse() const; Tensor isclose(const Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const; bool is_distributed() const; @@ -559,16 +560,25 @@ class CAFFE2_API Tensor { Tensor & set_(); bool is_set_to(const Tensor & tensor) const; Tensor & masked_fill_(const Tensor & mask, Scalar value); + Tensor masked_fill(const Tensor & mask, Scalar value) const; Tensor & masked_fill_(const Tensor & mask, const Tensor & value); + Tensor masked_fill(const Tensor & mask, const Tensor & value) const; Tensor & masked_scatter_(const Tensor & mask, const Tensor & source); + Tensor masked_scatter(const Tensor & mask, const Tensor & source) const; Tensor view(IntArrayRef size) const; Tensor & put_(const Tensor & index, const Tensor & source, bool accumulate=false); Tensor & index_add_(int64_t dim, const Tensor & index, const Tensor & source); + Tensor index_add(int64_t dim, const Tensor & index, const Tensor & source) const; Tensor & index_fill_(int64_t dim, const Tensor & index, Scalar value); + Tensor index_fill(int64_t dim, const Tensor & index, Scalar value) const; Tensor & index_fill_(int64_t dim, const Tensor & index, const Tensor & value); + Tensor index_fill(int64_t dim, const Tensor & index, const Tensor & value) const; Tensor & scatter_(int64_t dim, const Tensor & index, const Tensor & src); + Tensor scatter(int64_t dim, const Tensor & index, const Tensor & src) const; Tensor & scatter_(int64_t dim, const Tensor & index, Scalar value); + Tensor scatter(int64_t dim, const Tensor & index, Scalar value) const; Tensor & scatter_add_(int64_t dim, const Tensor & index, const Tensor & src); + Tensor scatter_add(int64_t dim, const Tensor & index, const Tensor & src) const; Tensor & lt_(Scalar other); Tensor & lt_(const Tensor & other); Tensor & gt_(Scalar other); diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h index 8c1ef503bfd..973d905dbfd 100644 --- a/aten/src/ATen/core/TensorMethods.h +++ b/aten/src/ATen/core/TensorMethods.h @@ -310,12 +310,15 @@ inline Tensor Tensor::index(TensorList indices) const { inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) { return type().index_copy_(*this, dim, index, source); } -inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const { - return type().index_put(*this, indices, values, accumulate); +inline Tensor Tensor::index_copy(int64_t dim, const Tensor & index, const Tensor & source) const { + return type().index_copy(*this, dim, index, source); } inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) { return type().index_put_(*this, indices, values, accumulate); } +inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const { + return type().index_put(*this, indices, values, accumulate); +} inline Tensor Tensor::inverse() const { return type().inverse(*this); } @@ -811,12 +814,21 @@ inline bool Tensor::is_set_to(const Tensor & tensor) const { inline Tensor & Tensor::masked_fill_(const Tensor & mask, Scalar value) { return type().masked_fill_(*this, mask, value); } +inline Tensor Tensor::masked_fill(const Tensor & mask, Scalar value) const { + return type().masked_fill(*this, mask, value); +} inline Tensor & Tensor::masked_fill_(const Tensor & mask, const Tensor & value) { return type().masked_fill_(*this, mask, value); } +inline Tensor Tensor::masked_fill(const Tensor & mask, const Tensor & value) const { + return type().masked_fill(*this, mask, value); +} inline Tensor & Tensor::masked_scatter_(const Tensor & mask, const Tensor & source) { return type().masked_scatter_(*this, mask, source); } +inline Tensor Tensor::masked_scatter(const Tensor & mask, const Tensor & source) const { + return type().masked_scatter(*this, mask, source); +} inline Tensor Tensor::view(IntArrayRef size) const { return type().view(*this, size); } @@ -826,21 +838,39 @@ inline Tensor & Tensor::put_(const Tensor & index, const Tensor & source, bool a inline Tensor & Tensor::index_add_(int64_t dim, const Tensor & index, const Tensor & source) { return type().index_add_(*this, dim, index, source); } +inline Tensor Tensor::index_add(int64_t dim, const Tensor & index, const Tensor & source) const { + return type().index_add(*this, dim, index, source); +} inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, Scalar value) { return type().index_fill_(*this, dim, index, value); } +inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, Scalar value) const { + return type().index_fill(*this, dim, index, value); +} inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, const Tensor & value) { return type().index_fill_(*this, dim, index, value); } +inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, const Tensor & value) const { + return type().index_fill(*this, dim, index, value); +} inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, const Tensor & src) { return type().scatter_(*this, dim, index, src); } +inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, const Tensor & src) const { + return type().scatter(*this, dim, index, src); +} inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, Scalar value) { return type().scatter_(*this, dim, index, value); } +inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, Scalar value) const { + return type().scatter(*this, dim, index, value); +} inline Tensor & Tensor::scatter_add_(int64_t dim, const Tensor & index, const Tensor & src) { return type().scatter_add_(*this, dim, index, src); } +inline Tensor Tensor::scatter_add(int64_t dim, const Tensor & index, const Tensor & src) const { + return type().scatter_add(*this, dim, index, src); +} inline Tensor & Tensor::lt_(Scalar other) { return type().lt_(*this, other); } diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h index e3d25d43b81..b90dd2629d9 100644 --- a/aten/src/ATen/core/Type.h +++ b/aten/src/ATen/core/Type.h @@ -272,8 +272,9 @@ struct CAFFE2_API Type { virtual Tensor irfft(const Tensor & self, int64_t signal_ndim, bool normalized, bool onesided, IntArrayRef signal_sizes) const = 0; virtual Tensor index(const Tensor & self, TensorList indices) const = 0; virtual Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0; - virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0; + virtual Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0; virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0; + virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0; virtual Tensor inverse(const Tensor & self) const = 0; virtual Tensor isclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan) const = 0; virtual bool is_distributed(const Tensor & self) const = 0; @@ -439,16 +440,25 @@ struct CAFFE2_API Type { virtual Tensor & set_(Tensor & self) const = 0; virtual bool is_set_to(const Tensor & self, const Tensor & tensor) const = 0; virtual Tensor & masked_fill_(Tensor & self, const Tensor & mask, Scalar value) const = 0; + virtual Tensor masked_fill(const Tensor & self, const Tensor & mask, Scalar value) const = 0; virtual Tensor & masked_fill_(Tensor & self, const Tensor & mask, const Tensor & value) const = 0; + virtual Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & value) const = 0; virtual Tensor & masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source) const = 0; + virtual Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) const = 0; virtual Tensor view(const Tensor & self, IntArrayRef size) const = 0; virtual Tensor & put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate) const = 0; virtual Tensor & index_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0; + virtual Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0; virtual Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0; + virtual Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0; virtual Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & value) const = 0; + virtual Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & value) const = 0; virtual Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0; + virtual Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0; virtual Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0; + virtual Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0; virtual Tensor & scatter_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0; + virtual Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0; virtual Tensor & lt_(Tensor & self, Scalar other) const = 0; virtual Tensor & lt_(Tensor & self, const Tensor & other) const = 0; virtual Tensor & gt_(Tensor & self, Scalar other) const = 0; diff --git a/aten/src/ATen/native/Indexing.cpp b/aten/src/ATen/native/Indexing.cpp index 12cdd489928..709124cf737 100644 --- a/aten/src/ATen/native/Indexing.cpp +++ b/aten/src/ATen/native/Indexing.cpp @@ -498,4 +498,50 @@ Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Ten return at::legacy::th::_th_index_copy_(self, dim, index, source); } +Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { + return self.clone().index_copy_(dim, index, source); +} + +Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { + return self.clone().index_add_(dim, index, source); +} + +Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) { + return self.clone().index_fill_(dim, index, source); +} + +Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { + return self.clone().index_fill_(dim, index, source); +} + +Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { + return self.clone().scatter_(dim, index, source); +} + +Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) { + return self.clone().scatter_(dim, index, source); +} + +Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { + return self.clone().scatter_add_(dim, index, source); +} + +Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) { + Tensor _mask, _self; + std::tie(_mask, _self) = expand_outplace(mask, self); + return _self.clone().masked_scatter_(_mask, source); +} + +Tensor masked_fill(const Tensor & self, const Tensor & mask, Scalar source) { + Tensor _mask, _self; + std::tie(_mask, _self) = expand_outplace(mask, self); + return _self.clone().masked_fill_(mask, source); +} + +Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & source) { + Tensor _mask, _self; + std::tie(_mask, _self) = expand_outplace(mask, self); + return _self.clone().masked_fill_(mask, source); +} + }} // at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 56f1ef35dce..8ec6170f00d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1107,10 +1107,11 @@ variants: function, method # NB: This function is special-cased in tools/autograd/gen_variable_type.py -- func: index_copy_(Tensor(a!) self, int dim, IndexTensor index, Tensor source) -> Tensor(a!) +- func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) + matches_jit_signature: True variants: method -- func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor +- func: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor matches_jit_signature: True variants: function, method @@ -1118,6 +1119,10 @@ matches_jit_signature: True variants: function, method +- func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + matches_jit_signature: True + variants: function, method + - func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor matches_jit_signature: True variants: function @@ -3056,14 +3061,26 @@ matches_jit_signature: True variants: method +- func: masked_fill(Tensor self, Tensor mask, Scalar value) -> Tensor + matches_jit_signature: True + variants: function, method + - func: masked_fill_(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!) matches_jit_signature: True variants: method +- func: masked_fill(Tensor self, Tensor mask, Tensor value) -> Tensor + matches_jit_signature: True + variants: function, method + - func: masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!) matches_jit_signature: True variants: method +- func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor + matches_jit_signature: True + variants: function, method + - func: view(Tensor(a) self, int[] size) -> Tensor(a) matches_jit_signature: True variants: method @@ -3077,26 +3094,50 @@ matches_jit_signature: True variants: method +- func: index_add(Tensor self, int dim, Tensor index, Tensor source) -> Tensor + matches_jit_signature: True + variants: function, method + - func: index_fill_(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) matches_jit_signature: True variants: method +- func: index_fill(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + matches_jit_signature: True + variants: function, method + - func: index_fill_(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!) matches_jit_signature: True variants: method +- func: index_fill(Tensor self, int dim, Tensor index, Tensor value) -> Tensor + matches_jit_signature: True + variants: function, method + - func: scatter_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) matches_jit_signature: True variants: method +- func: scatter(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + matches_jit_signature: True + variants: function, method + - func: scatter_(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) matches_jit_signature: True variants: method +- func: scatter(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + matches_jit_signature: True + variants: function, method + - func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) matches_jit_signature: True variants: method +- func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + matches_jit_signature: True + variants: function, method + - func: lt_(Tensor(a!) self, Scalar other) -> Tensor(a!) matches_jit_signature: True variants: method diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 4e893920820..cb0aae03811 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -252,9 +252,13 @@ view of a storage and defines numeric operations on it. .. automethod:: half .. automethod:: histc .. automethod:: index_add_ + .. automethod:: index_add .. automethod:: index_copy_ + .. automethod:: index_copy .. automethod:: index_fill_ + .. automethod:: index_fill .. automethod:: index_put_ + .. automethod:: index_put .. automethod:: index_select .. automethod:: int .. automethod:: inverse @@ -285,7 +289,9 @@ view of a storage and defines numeric operations on it. .. automethod:: lt_ .. automethod:: map_ .. automethod:: masked_scatter_ + .. automethod:: masked_scatter .. automethod:: masked_fill_ + .. automethod:: masked_fill .. automethod:: masked_select .. automethod:: matmul .. automethod:: matrix_power @@ -346,7 +352,9 @@ view of a storage and defines numeric operations on it. .. automethod:: rsqrt .. automethod:: rsqrt_ .. automethod:: scatter_ + .. automethod:: scatter .. automethod:: scatter_add_ + .. automethod:: scatter_add .. automethod:: select .. automethod:: set_ .. automethod:: share_memory_ diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py index 0c8be4ced4e..d7ae03ce354 100644 --- a/test/common_methods_invocations.py +++ b/test/common_methods_invocations.py @@ -728,7 +728,8 @@ def method_tests(): ('gather', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_both', [0]), ('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]), ('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]), - ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', [0]), + ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalartensor_all_dim0', [0]), + ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), 2.5), 'scalar_all_dim0', [0]), ('scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]), ('scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]), ('scatter_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', [0]), @@ -741,15 +742,17 @@ def method_tests(): ('masked_select', (M, M), (torch.tensor(1, dtype=torch.uint8),), 'scalar_broadcast_rhs'), ('masked_select', (), (mask_not_all_zeros((M, M)),), 'scalar_broadcast_lhs'), ('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), 10)), - ('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), torch.tensor(10)), 'tensor'), - # no lhs or all broadcast on masked_fill or masked_scatter because it's always inplace + ('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), ()), 'tensor'), + ('masked_fill', (M,), (torch.ByteTensor(M, M).bernoulli_(), 10), 'broadcast_lhs'), ('masked_fill', (M, M), (torch.ByteTensor(M,).bernoulli_(), 10), 'broadcast_rhs'), - ('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10), 'scalar'), - ('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), torch.tensor(10)), + ('masked_fill', (), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), 10), 'scalar'), + ('masked_fill', (), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), ()), 'scalar_variable'), - ('masked_fill', (M, M), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10), + ('masked_fill', (M, M), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), 10), 'scalar_broadcast_rhs'), ('masked_scatter', (M, M), (torch.ByteTensor(M, M).bernoulli_(), (M, M))), + ('masked_scatter', (M,), (torch.ByteTensor(M, M).bernoulli_(), (M, M)), + 'broadcast_lhs'), ('masked_scatter', (M, M), (torch.ByteTensor(M,).bernoulli_(), (M, M)), 'broadcast_rhs'), ('masked_scatter', (M, M), (bernoulli_scalar(), (M, M)), 'scalar'), diff --git a/test/test_jit.py b/test/test_jit.py index 0c89ae22129..fbb60fa37e8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -9165,7 +9165,7 @@ a") def test_builtin_error_messsage(self): from torch.nn.modules.utils import _single, _pair, _triple, _quadruple - with self.assertRaisesRegex(RuntimeError, "aten::masked_fill_"): + with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): @torch.jit.script def close_match(x): return x.masked_fill(True) diff --git a/test/test_torch.py b/test/test_torch.py index d6269f25683..d8e8ebd1eed 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -3455,6 +3455,7 @@ class _TestTorchMixin(object): for fn in fns: (dims_small, dims_large, dims_full) = self._select_broadcastable_dims() + full1d = cast(torch.randn(*dims_full).flatten().float()) small = cast(torch.randn(*dims_small).float()) large = cast(torch.randn(*dims_large).float()) small_expanded = small.expand(*dims_full) @@ -3471,8 +3472,7 @@ class _TestTorchMixin(object): # map and map2 are not implementd on CUDA tensors continue - # TODO: fix masked_scatter and masked_fill broadcasting - if hasattr(large_expanded, fn) and fn not in ['masked_scatter', 'masked_fill']: + if hasattr(large_expanded, fn): # run through tensor versions of functions # and verify fully expanded inputs give same results expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded} @@ -3482,6 +3482,10 @@ class _TestTorchMixin(object): return myfn(t1, 0.5) elif fn == "masked_select": return myfn(t1 < 0) + elif fn == "masked_scatter": + return myfn(t1 < 0.5, full1d) + elif fn == "masked_fill": + return myfn(t1 < 0.5, 1.0) elif fn in fns_3_args: return myfn(1, t1, t2) else: @@ -3509,7 +3513,7 @@ class _TestTorchMixin(object): elif fn == "masked_select": return fntorch(t1, t2 < 0) elif fn == "masked_scatter": - return fntorch(t1, t2 < 0.5, cast(torch.arange(1, t1.nelement() + 1).float())) + return fntorch(t1, t2 < 0.5, full1d) elif fn == "masked_fill": return fntorch(t1, t2 < 0.5, 1.0) elif fn in fns_3_args: @@ -3540,7 +3544,7 @@ class _TestTorchMixin(object): if fn == "lerp": return t0_fn(t1, 0.5) elif fn == "masked_scatter": - return t0_fn(t1 < 0.5, cast(torch.arange(1, t0.nelement() + 1).float())) + return t0_fn(t1 < 0.5, full1d) elif fn == "masked_fill": return t0_fn(t1 < 0.5, 1.0) elif fn == "map": diff --git a/torch/__init__.pyi.in b/torch/__init__.pyi.in index 348e3bad42b..aa3c1f0832e 100644 --- a/torch/__init__.pyi.in +++ b/torch/__init__.pyi.in @@ -80,13 +80,6 @@ class Tensor: def stft(self, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode='reflect', normalized=False, onesided=True): ... def split(self, split_size, dim=0): ... - def index_add(self, dim, index, tensor): ... - def index_copy(self, dim, index, tensor): ... - def index_fill(self, dim, index, value): ... - def scatter(self, dim, index, source): ... - def scatter_add(self, dim, index, source): ... - def masked_scatter(self, mask, tensor): ... - def masked_fill(self, mask, value): ... def unique(self, sorted=True, return_inverse=False, dim=None): ... ${function_hints} diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index f56b053f009..f524e56c447 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -2963,6 +2963,55 @@ pinverse() -> Tensor See :func:`torch.pinverse` """) +add_docstr_all('index_add', + r""" +index_add(dim, index, tensor) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.index_add_` +""") + +add_docstr_all('index_copy', + r""" +index_copy(dim, index, tensor) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.index_copy_` +""") + +add_docstr_all('index_fill', + r""" +index_fill(dim, index, value) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.index_fill_` +""") + +add_docstr_all('scatter', + r""" +scatter(dim, index, source) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_` +""") + +add_docstr_all('scatter_add', + r""" +scatter_add(dim, index, source) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_add_` +""") + +add_docstr_all('masked_scatter', + r""" +masked_scatter(mask, tensor) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.masked_scatter_` +""") + +add_docstr_all('masked_fill', + r""" +masked_fill(mask, value) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.masked_fill_` +""") + add_docstr_all('grad', r""" This attribute is ``None`` by default and becomes a Tensor the first time a call to diff --git a/torch/tensor.py b/torch/tensor.py index cb31ac824e2..2e1322c4632 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -307,41 +307,6 @@ class Tensor(torch._C._TensorBase): else: return super(Tensor, self).split_with_sizes(split_size, dim) - def index_add(self, dim, index, tensor): - r"""Out-of-place version of :meth:`torch.Tensor.index_add_` - """ - return self.clone().index_add_(dim, index, tensor) - - def index_copy(self, dim, index, tensor): - r"""Out-of-place version of :meth:`torch.Tensor.index_copy_` - """ - return self.clone().index_copy_(dim, index, tensor) - - def index_fill(self, dim, index, value): - r"""Out-of-place version of :meth:`torch.Tensor.index_fill_` - """ - return self.clone().index_fill_(dim, index, value) - - def scatter(self, dim, index, source): - r"""Out-of-place version of :meth:`torch.Tensor.scatter_` - """ - return self.clone().scatter_(dim, index, source) - - def scatter_add(self, dim, index, source): - r"""Out-of-place version of :meth:`torch.Tensor.scatter_add_` - """ - return self.clone().scatter_add_(dim, index, source) - - def masked_scatter(self, mask, tensor): - r"""Out-of-place version of :meth:`torch.Tensor.masked_scatter_` - """ - return self.clone().masked_scatter_(mask, tensor) - - def masked_fill(self, mask, value): - r"""Out-of-place version of :meth:`torch.Tensor.masked_fill_` - """ - return self.clone().masked_fill_(mask, value) - def unique(self, sorted=True, return_inverse=False, dim=None): r"""Returns the unique scalar elements of the tensor as a 1-D tensor.