Make tril_ and triu_ actually in-place (#17031)

Summary:
Currently, when the input tensor `self` is not contiguous, `tril_` and `triu_` call `self = self.contiguous()`, which allocates a new contiguous tensor and assigns it to `self`. This effectively swaps out the input tensor `self`'s TensorImpl pointer and will break downstream code after the Variable/Tensor merge.
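
For reference, a condensed sketch of the old CPU path (lifted from the diff below); the rebinding on the `contiguous()` line is what swaps out `self`'s TensorImpl:

Tensor& tril_cpu_(Tensor &self, int64_t k) {
  // Rebinding `self` points the caller's handle at a brand-new contiguous tensor,
  // so code still holding the original TensorImpl never sees the triangular result.
  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
  AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
    apply_triu_tril<scalar_t, /*inplace=*/true, /*upper=*/false>(self, self, k);
  });
  return self;
}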

This PR fixes that: `tril_` and `triu_` now always update the input tensor in-place and preserve the input tensor's TensorImpl.
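
Condensed, the new pattern (as in the CPU diff below) works on contiguous scratch tensors when needed and writes the values back through the original `self`:

Tensor& tril_cpu_(Tensor &self, int64_t k) {
  bool inplace = checkTrilTriuBatchContiguous(self);
  Tensor self_c = inplace ? self : self.contiguous();     // scratch input when non-contiguous
  Tensor result = inplace ? self : at::empty_like(self);  // scratch output when non-contiguous
  AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
    apply_triu_tril<scalar_t, /*upper=*/false>(result, self_c, inplace, k);
  });
  if (!inplace) self.copy_(result);  // copy back: same storage, same TensorImpl
  return self;
}
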
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17031

Differential Revision: D14069592

Pulled By: yf225

fbshipit-source-id: d188218f426446a44ccc1d33fc28ac3f828c6a05
Will Feng 2019-02-19 14:31:34 -08:00 committed by Facebook Github Bot
parent 0fc03d155a
commit c88798dbc1
4 changed files with 33 additions and 25 deletions


@@ -391,9 +391,9 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
   return result;
 }
-template <typename scalar_t, bool inplace, bool upper>
+template <typename scalar_t, bool upper>
 static void apply_triu_tril_single(
-    scalar_t* result, scalar_t* self,
+    scalar_t* result, scalar_t* self, bool inplace,
     int64_t k, int64_t n, int64_t m,
     int64_t res_row_stride, int64_t res_col_stride,
     int64_t self_row_stride, int64_t self_col_stride) {
@@ -428,8 +428,8 @@ static void apply_triu_tril_single(
   }
 }
-template <typename scalar_t, bool inplace, bool upper>
-void apply_triu_tril(Tensor& result, const Tensor& self, int64_t k) {
+template <typename scalar_t, bool upper>
+void apply_triu_tril(Tensor& result, const Tensor& self, bool inplace, int64_t k) {
   auto n = self.size(-2);
   auto m = self.size(-1);
   auto self_data = self.data<scalar_t>();
@@ -455,8 +455,8 @@ void apply_triu_tril(Tensor& result, const Tensor& self, int64_t k) {
   for (b = 0; b < batchsize; b++) {
     scalar_t* self_batch = &self_data[b * self_stride];
     scalar_t* result_batch = &result_data[b * result_stride];
-    apply_triu_tril_single<scalar_t, inplace, upper>(
-      result_batch, self_batch, k, n, m,
+    apply_triu_tril_single<scalar_t, upper>(
+      result_batch, self_batch, inplace, k, n, m,
       result_row_stride, result_column_stride, self_row_stride, self_column_stride);
   }
 }
@@ -471,10 +471,13 @@ Tensor& tril_cpu_(Tensor &self, int64_t k) {
   if (self.numel() == 0) {
     return self;
   }
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
   AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
-    apply_triu_tril<scalar_t, true, false>(self, self, k);
+    apply_triu_tril<scalar_t, false>(result, self_c, inplace, k);
   });
+  if (!inplace) self.copy_(result);
   return self;
 }
@@ -487,7 +490,7 @@ Tensor& tril_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
   }
   Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
   AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
-    apply_triu_tril<scalar_t, false, false>(result, self_c, k);
+    apply_triu_tril<scalar_t, false>(result, self_c, false, k);
   });
   return result;
 }
@@ -502,10 +505,13 @@ Tensor& triu_cpu_(Tensor &self, int64_t k) {
   if (self.numel() == 0) {
     return self;
   }
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
   AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
-    apply_triu_tril<scalar_t, true, true>(self, self, k);
+    apply_triu_tril<scalar_t, true>(result, self_c, inplace, k);
   });
+  if (!inplace) self.copy_(result);
   return self;
 }
@@ -518,7 +524,7 @@ Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
   }
   Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
   AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
-    apply_triu_tril<scalar_t, false, true>(result, self_c, k);
+    apply_triu_tril<scalar_t, true>(result, self_c, false, k);
   });
   return result;
 }


@@ -505,8 +505,12 @@ Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, c
 }
 Tensor& tril_cuda_(Tensor &self, int64_t k) {
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
-  return tril_cuda_out(self, self, k);
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
+  tril_cuda_out(result, self_c, k);
+  if (!inplace) self.copy_(result);
+  return self;
 }
 Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
@@ -521,8 +525,12 @@ Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
 }
 Tensor& triu_cuda_(Tensor &self, int64_t k) {
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
-  return triu_cuda_out(self, self, k);
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
+  triu_cuda_out(result, self_c, k);
+  if (!inplace) self.copy_(result);
+  return self;
 }
 Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {


@@ -4149,18 +4149,14 @@ class _TestTorchMixin(object):
 assert not x_nc.is_contiguous(), "x is intentionally non-contiguous"
 exp_nc = torch.where(exp_mask, torch.tensor(0).type_as(x), x_nc)
 self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0)
+x_nc_is_contiguous = x_nc.is_contiguous()
 if upper:
     self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0)
 else:
     self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0)
-# any 3-dimensional tensor should be fine
-if len(shape) <= 3 or s == -2:
-    self.assertFalse(x_nc.is_contiguous(),
-                     "x_nc should remain non-contiguous")
-elif s < -3:
-    self.assertTrue(x_nc.is_contiguous(),
-                    "x_nc should become contiguous")
+self.assertTrue(x_nc.is_contiguous() == x_nc_is_contiguous,
+                "contiguity of x_nc should not be changed")
 # expanded tensors
 expanded_size = (x.size(0),) + x.size()

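The updated test only asserts that contiguity is unchanged; a minimal standalone check of the same invariant against the ATen C++ API (a hedged sketch assuming a plain ATen build, not part of this PR) could look like:

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::randn({4, 4}).t();   // transpose makes x non-contiguous
  void* ptr_before = x.data_ptr();
  x.tril_(0);                             // after this change, x is updated truly in-place
  AT_ASSERT(x.data_ptr() == ptr_before);  // storage (and TensorImpl) preserved
  AT_ASSERT(!x.is_contiguous());          // contiguity is not silently changed
  return 0;
}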

@@ -136,8 +136,6 @@ for (size_t i=0; i<${tensorlist_name}.size(); i++) {
 DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
     # These functions are expected to change impl or storage of input tensors
     '_th_set_', '_cudnn_rnn_flatten_weight',
-    # TODO: Fix these functions to update input tensor in-place
-    'tril_', 'triu_',
 }
 # END CHECKS FOR [ Invariant: TensorImpl and Storage Pointer Equality ]