[bc-breaking] Dispatch index_put with boolean mask argument to masked_fill (#61612)

Summary:
https://github.com/pytorch/pytorch/issues/57515

Based on ngimel's branch, with a few tweaks to determine when to copy value tensors to device memory, plus additional tests.
bc-breaking note: Previously, `x[index] = value` raised a RuntimeError when `value` was a 0-d tensor on a device different from `x`'s device. Now this case is handled by copying `value` to `x`'s device.
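
For illustration (not part of the commit), a minimal sketch of the user-visible behavior; a CUDA device is assumed:

```python
import torch

x = torch.arange(6., device="cuda")

# Boolean-mask assignment with a scalar value can now be routed to
# masked_fill_ instead of the generic index_put_ path.
x[x < 3] = 1.0

# Previously this raised a RuntimeError because the 0-d value tensor lives on
# a different device than x; it is now copied to x's device.
x[torch.tensor([4])] = torch.tensor(7.0)

print(x)  # tensor([1., 1., 1., 3., 7., 5.], device='cuda:0')
```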

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61612

Reviewed By: mrshenli

Differential Revision: D29753491

Pulled By: ngimel

fbshipit-source-id: 3fba14f4c2b9b136b50af020f9c1eda88f7373b0
Authored by Eddie Yan on 2021-07-19 22:52:04 -07:00; committed by Facebook GitHub Bot
parent 018dc4193e
commit 42d6543c7b
5 changed files with 75 additions and 6 deletions

View File

@@ -46,11 +46,15 @@ static inline void set_item(const Tensor& self, ArrayRef<TensorIndex> indices, c
  {
    at::AutoDispatchBelowADInplaceOrView guard;
    at::Device self_device = self.device();
    // TODO: This qint special case looks very suspicious...
    if (isQIntType(self.scalar_type())) {
      value = at::indexing::scalarToTensor(v, device(kCPU).dtype(kFloat), at::Device(kCPU));
    } else if (self_device.is_cuda()) {
      value = at::indexing::scalarToTensor(v, self.options(), at::Device(kCPU));
    } else {
-     value = at::indexing::scalarToTensor(v, self.options(), self.device());
+     value = at::indexing::scalarToTensor(v, self.options(), self_device);
    }
  }

View File

@@ -352,6 +352,9 @@ static inline void copy_to(const Tensor& dst, const Tensor& src) {
    // appear. Users can workaround that case by dst[index..] = src.reshape(..)
    dst.copy_(src);
    return;
  } else if (src.sizes().size() == 0 && src.device().type() == at::kCPU) {
    dst.fill_(src.item());
    return;
  }
  auto src_view = src.view(slicePrefix1sSize(src.sizes()));
  c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");

View File

@@ -290,6 +290,39 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
  }
}

static std::tuple<bool, Tensor> canDispatchToMaskedFill(const Tensor& self, const torch::List<c10::optional<at::Tensor>>& indices,
    const Tensor& value){
  if (!(value.numel() ==1 && value.device().is_cpu())){
    return std::make_tuple(false,Tensor());
  }
  int64_t num_ind = 0;
  Tensor mask;
  auto self_device = self.device();
  for (const c10::optional<Tensor> i: indices) {
    if (!i.has_value() || !(*i).defined()){
      num_ind++;
    } else {
      Tensor index = std::move(*i);
      if ((index.scalar_type() != kByte && index.scalar_type() != kBool) ||
          index.device() != self_device || mask.defined()){
        return std::make_tuple(false, Tensor());
      } else {
        mask = index;
        for (int64_t j = 0; j < index.dim(); j++) {
          int64_t srcIdx = num_ind + j;
          TORCH_CHECK_INDEX(index.size(j) == self.size(srcIdx), "The shape of the mask ", index.sizes(), " at index ", j,
            " does not match the shape of the indexed tensor ", self.sizes(), " at index ", srcIdx);
        }
        num_ind += mask.ndimension();
      }
    }
  }
  for (int64_t i = num_ind; i< self.ndimension(); i++){
    mask = mask.unsqueeze(-1);
  }
  return std::make_tuple(true, mask);
}

static AdvancedIndex make_info(Tensor self, const torch::List<c10::optional<at::Tensor>>& orig) {
  checkIndexTensorTypes(orig);
  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
@@ -470,6 +503,16 @@ Tensor & _index_put_impl_(Tensor & self, const torch::List<c10::optional<Tensor>
      "Please clone() the tensor before performing this operation. "
      "This also applies to advanced indexing e.g. tensor[indices] = tensor");
  }
  if (!accumulate) {
    auto masked_fill_dispatch = canDispatchToMaskedFill(self, indices, value);
    if (std::get<0>(masked_fill_dispatch)) {
      return self.masked_fill_(std::get<1>(masked_fill_dispatch), value.item());
    }
  }
  auto value_ = value;
  if (value.device() != self.device() && value.numel() == 1 && value.dim() == 0) {
    value_ = value.to(self.device());
  }
  at::assert_no_overlap(self, value);
  // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
  for (const c10::optional<Tensor>& index: indices) {
@@ -477,16 +520,15 @@ Tensor & _index_put_impl_(Tensor & self, const torch::List<c10::optional<Tensor>
      at::assert_no_overlap(self, *index);
    }
  }
  if (self.device().type() == DeviceType::CUDA && (accumulate || globalContext().deterministicAlgorithms())) {
-   TORCH_CHECK(value.device() == self.device(), "expected device ", self.device(), " but got device ",
-     value.device(), " for value tensor");
-   index_put_with_sort_stub(self.device().type(), self, indices, value, accumulate, unsafe);
+   TORCH_CHECK(value_.device() == self.device(), "expected device ", self.device(), " but got device ",
+     value_.device(), " for value tensor");
+   index_put_with_sort_stub(self.device().type(), self, indices, value_, accumulate, unsafe);
    return self;
  }
  auto info = make_info(self, indices);
- auto iter = make_index_put_iterator(info, value);
+ auto iter = make_index_put_iterator(info, value_);
  index_put_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides, accumulate);
  return self;
}
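
As a reference point (illustrative sketch, not part of the commit): when accumulate is false, the value is a single-element CPU tensor, and the only index is a boolean mask on self's device, the put above is equivalent to a masked_fill_; a CUDA device is assumed here.

```python
import torch

x = torch.randn(4, 5, device="cuda")   # device assumed for illustration
mask = x > 0                            # bool mask on x's device
value = torch.tensor(0.5)               # single-element CPU value tensor

# With accumulate=False, a single bool/byte mask index, and a 1-element CPU
# value, index_put_ now takes the masked_fill_ route; the two calls below
# produce the same result.
a = x.clone().index_put_((mask,), value, accumulate=False)
b = x.clone().masked_fill_(mask, value.item())
assert torch.equal(a, b)
```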

View File

@@ -796,6 +796,24 @@ class TestIndexing(TestCase):
        r = v[c > 0]
        self.assertEqual(r.shape, (num_ones, 3))

    def test_jit_indexing(self, device):
        def fn1(x):
            x[x < 50] = 1.0
            return x

        def fn2(x):
            x[0:50] = 1.0
            return x

        scripted_fn1 = torch.jit.script(fn1)
        scripted_fn2 = torch.jit.script(fn2)
        data = torch.arange(100, device=device, dtype=torch.float)
        out = scripted_fn1(data.detach().clone())
        ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
        self.assertEqual(out, ref)
        out = scripted_fn2(data.detach().clone())
        self.assertEqual(out, ref)

    def test_int_indices(self, device):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))

View File

@@ -370,6 +370,8 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
    // TODO: This qint special case looks very suspicious...
    if (isQIntType(self_.scalar_type())) {
      value = valueToTensor(device(kCPU).dtype(kFloat), py_value, at::Device(kCPU));
    } else if (self_device.is_cuda()) {
      value = valueToTensor(self_.options(), py_value, at::Device(kCPU));
    } else {
      value = valueToTensor(self_.options(), py_value, self_device);
    }