Make tril_ and triu_ actually in-place (#17031)
Summary: Currently, when the input tensor `self` is not contiguous, `tril_` and `triu_` call `self = self.contiguous()`, which allocates a new contiguous tensor and assigns it to `self`. This effectively changes the input tensor `self`'s pointer and will break downstream code after the Variable/Tensor merge. This PR fixes it so that `tril_` and `triu_` always update the input tensor in place and preserve its TensorImpl.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17031
Differential Revision: D14069592
Pulled By: yf225
fbshipit-source-id: d188218f426446a44ccc1d33fc28ac3f828c6a05
This commit is contained in:
parent
0fc03d155a
commit
c88798dbc1
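
To make the guarantee concrete, here is a minimal sketch (not part of the PR) of the user-visible behavior the fix establishes: calling `tril_` on a non-contiguous tensor must modify that tensor in place, leaving its Python identity, storage pointer, and contiguity unchanged.

import torch

# Minimal sketch of the in-place contract described above; assumes a PyTorch
# build that includes this fix. A transposed view is a simple non-contiguous input.
x = torch.randn(4, 4).t()
assert not x.is_contiguous()

ptr_before = x.data_ptr()
contiguous_before = x.is_contiguous()

y = x.tril_()  # in-place: zero out everything above the main diagonal

assert y is x                                   # the very same tensor object is returned
assert x.data_ptr() == ptr_before               # the storage was updated, not replaced
assert x.is_contiguous() == contiguous_before   # contiguity is left untouched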
@@ -391,9 +391,9 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
   return result;
 }
 
-template <typename scalar_t, bool inplace, bool upper>
+template <typename scalar_t, bool upper>
 static void apply_triu_tril_single(
-    scalar_t* result, scalar_t* self,
+    scalar_t* result, scalar_t* self, bool inplace,
     int64_t k, int64_t n, int64_t m,
     int64_t res_row_stride, int64_t res_col_stride,
     int64_t self_row_stride, int64_t self_col_stride) {
@@ -428,8 +428,8 @@ static void apply_triu_tril_single(
   }
 }
 
-template <typename scalar_t, bool inplace, bool upper>
-void apply_triu_tril(Tensor& result, const Tensor& self, int64_t k) {
+template <typename scalar_t, bool upper>
+void apply_triu_tril(Tensor& result, const Tensor& self, bool inplace, int64_t k) {
   auto n = self.size(-2);
   auto m = self.size(-1);
   auto self_data = self.data<scalar_t>();
@@ -455,8 +455,8 @@ void apply_triu_tril(Tensor& result, const Tensor& self, int64_t k) {
   for (b = 0; b < batchsize; b++) {
     scalar_t* self_batch = &self_data[b * self_stride];
     scalar_t* result_batch = &result_data[b * result_stride];
-    apply_triu_tril_single<scalar_t, inplace, upper>(
-      result_batch, self_batch, k, n, m,
+    apply_triu_tril_single<scalar_t, upper>(
+      result_batch, self_batch, inplace, k, n, m,
       result_row_stride, result_column_stride, self_row_stride, self_column_stride);
   }
 }
@@ -471,10 +471,13 @@ Tensor& tril_cpu_(Tensor &self, int64_t k) {
   if (self.numel() == 0) {
     return self;
   }
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
   AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
-    apply_triu_tril<scalar_t, true, false>(self, self, k);
+    apply_triu_tril<scalar_t, false>(result, self_c, inplace, k);
   });
+  if (!inplace) self.copy_(result);
   return self;
 }
 
@@ -487,7 +490,7 @@ Tensor& tril_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
   }
   Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
   AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
-    apply_triu_tril<scalar_t, false, false>(result, self_c, k);
+    apply_triu_tril<scalar_t, false>(result, self_c, false, k);
   });
   return result;
 }
@@ -502,10 +505,13 @@ Tensor& triu_cpu_(Tensor &self, int64_t k) {
   if (self.numel() == 0) {
     return self;
   }
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
   AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
-    apply_triu_tril<scalar_t, true, true>(self, self, k);
+    apply_triu_tril<scalar_t, true>(result, self_c, inplace, k);
   });
+  if (!inplace) self.copy_(result);
   return self;
 }
 
@@ -518,7 +524,7 @@ Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
   }
   Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
   AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
-    apply_triu_tril<scalar_t, false, true>(result, self_c, k);
+    apply_triu_tril<scalar_t, true>(result, self_c, false, k);
   });
   return result;
 }

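The change in `tril_cpu_` and `triu_cpu_` above (mirrored in the CUDA functions below) follows one pattern: run the kernel on a batch-contiguous view of the input, then copy the values back into `self` if a temporary had to be used. The following rough Python analogue is purely illustrative, not the actual implementation; `is_contiguous()` stands in for `checkTrilTriuBatchContiguous`, and `tril_inplace_sketch` is a hypothetical helper name.

import torch

def tril_inplace_sketch(self, k=0):
    # Rough analogue of the new control flow in tril_cpu_ / tril_cuda_; not the real code.
    inplace = self.is_contiguous()                # stand-in for checkTrilTriuBatchContiguous(self)
    self_c = self if inplace else self.contiguous()
    result = self if inplace else torch.empty_like(self)
    torch.tril(self_c, diagonal=k, out=result)    # compute into self or into a scratch tensor
    if not inplace:
        self.copy_(result)                        # write back into self's existing storage
    return self

Because `self.copy_(result)` writes through `self`'s existing storage, the tensor's TensorImpl and data pointer survive the operation, which is exactly the invariant that the autograd checks in the last hunk now enforce for `tril_` and `triu_`.
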
@@ -505,8 +505,12 @@ Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, c
 }
 
 Tensor& tril_cuda_(Tensor &self, int64_t k) {
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
-  return tril_cuda_out(self, self, k);
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
+  tril_cuda_out(result, self_c, k);
+  if (!inplace) self.copy_(result);
+  return self;
 }
 
 Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
@@ -521,8 +525,12 @@ Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
 }
 
 Tensor& triu_cuda_(Tensor &self, int64_t k) {
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
-  return triu_cuda_out(self, self, k);
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
+  triu_cuda_out(result, self_c, k);
+  if (!inplace) self.copy_(result);
+  return self;
 }
 
 Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {

@@ -4149,18 +4149,14 @@ class _TestTorchMixin(object):
             assert not x_nc.is_contiguous(), "x is intentionally non-contiguous"
             exp_nc = torch.where(exp_mask, torch.tensor(0).type_as(x), x_nc)
             self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0)
+            x_nc_is_contiguous = x_nc.is_contiguous()
             if upper:
                 self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0)
             else:
                 self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0)
 
-            # any 3-dimensional tensor should be fine
-            if len(shape) <= 3 or s == -2:
-                self.assertFalse(x_nc.is_contiguous(),
-                                 "x_nc should remain non-contiguous")
-            elif s < -3:
-                self.assertTrue(x_nc.is_contiguous(),
-                                "x_nc should become contiguous")
+            self.assertTrue(x_nc.is_contiguous() == x_nc_is_contiguous,
+                            "contiguity of x_nc should not be changed")
 
             # expanded tensors
             expanded_size = (x.size(0),) + x.size()

@@ -136,8 +136,6 @@ for (size_t i=0; i<${tensorlist_name}.size(); i++) {
 DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
     # These functions are expected to change impl or storage of input tensors
     '_th_set_', '_cudnn_rnn_flatten_weight',
-    # TODO: Fix these functions to update input tensor in-place
-    'tril_', 'triu_',
 }
 # END CHECKS FOR [ Invariant: TensorImpl and Storage Pointer Equality ]
 