Add nested squeeze.dim and unsqueeze (#86813)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86813
Approved by: https://github.com/drisspg
Authored by Mikayla Gawarecki on 2022-10-12 22:31:13 +00:00, committed by PyTorch MergeBot
parent e531cf7b2e
commit ab69550678
4 changed files with 164 additions and 3 deletions
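
For orientation, here is a minimal usage sketch of what this commit enables, mirroring the test cases further down (tensor shapes and values are illustrative, not taken from the PR):

import torch

# two ragged samples packed into one nested tensor of logical shape (2, ragged, 3)
a = torch.arange(6, dtype=torch.float).reshape(2, 3)
b = torch.arange(15, dtype=torch.float).reshape(5, 3)
nt = torch.nested.nested_tensor([a, b])

nt_u = nt.unsqueeze(-1)  # insert a trailing size-1 dim -> (2, ragged, 3, 1)
nt_s = nt_u.squeeze(-1)  # drop it again -> (2, ragged, 3)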

View File

@@ -4893,6 +4893,7 @@
dispatch:
CompositeExplicitAutograd: squeeze
QuantizedCPU, QuantizedCUDA: squeeze_quantized
NestedTensorCPU, NestedTensorCUDA: squeeze_nested
- func: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
variants: function, method
@@ -4901,6 +4902,7 @@
dispatch:
CompositeExplicitAutograd: squeeze
QuantizedCPU, QuantizedCUDA: squeeze_quantized
NestedTensorCPU, NestedTensorCUDA: squeeze_dim_nested
tags: canonical
- func: squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a)
@@ -5531,6 +5533,7 @@
CompositeExplicitAutograd: unsqueeze
SparseCPU, SparseCUDA: unsqueeze_sparse
QuantizedCPU, QuantizedCUDA: unsqueeze_quantized
NestedTensorCPU, NestedTensorCUDA: unsqueeze_nested
tags: canonical
- func: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)
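
The added NestedTensorCPU/NestedTensorCUDA entries route nested-tensor inputs to the kernels defined in the next file. A rough sketch of the resulting behavior, paraphrasing the TORCH_CHECK messages added below (values illustrative):

import torch

nt = torch.nested.nested_tensor([torch.ones(2, 3), torch.ones(5, 3)])
nt.unsqueeze(1)    # -> unsqueeze_nested: returns a nested view with a new size-1 dim
nt.squeeze(-1)     # -> squeeze_dim_nested: no-op here because the last dim is not size 1
# nt.squeeze()     # -> squeeze_nested: raises, the dim argument is required for nested tensors
# nt.unsqueeze(0)  # raises: unsqueezing/squeezing the batch dimension is not supported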

View File

@@ -1041,6 +1041,73 @@ Tensor transpose_nested(const Tensor& self, int64_t dim0, int64_t dim1) {
self, sizemat_transposed, stridemat_transposed, std::vector<int64_t>(self_ptr->get_storage_offsets()));
}
Tensor squeeze_nested(const Tensor& self) {
TORCH_CHECK(false,
"squeeze(): For nested tensors, squeeze without the dim argument is not supported ",
"at the moment, however you can use squeeze(Tensor self, int dim) instead ",
"if you need this feature, please open an issue on github describing your use case.");
return self;
}
Tensor squeeze_dim_nested(const Tensor& self, int64_t dim) {
auto self_ptr = get_nested_tensor_impl(self);
int64_t ndim = self_ptr->dim();
int64_t wrapped_dim = at::maybe_wrap_dim(dim, ndim);
TORCH_CHECK(wrapped_dim > 0,
"squeeze(): For nested tensors, squeezing dimension 0 is not supported at the moment ",
"if you need this feature, please open an issue on github describing your use case.");
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
const Tensor& stridemat = self_ptr->get_nested_stride_tensor();
// if tensor.size(dim) != 1 torch.squeeze will return the result, we do the same here
c10::optional<int64_t> size_dim = self_ptr->opt_size(dim);
if (!(size_dim.has_value() && size_dim.value() == 1)) {
// detach to avoid triggering throw_error_if_base_and_tensor_are_same
return self.detach();
}
// if ndim == 2 and we pass the above if statement we should have a
// nested tensor of singleton tensors
TORCH_CHECK(ndim != 2,
"squeeze(): For nested tensors, squeezing a nested tensor of singleton tensors is not ",
"supported at the moment, if you need this feature, please open an issue on github",
"describing your use case.");
auto column_indices = sizemat.new_empty(ndim - 2);
int64_t* column_indices_ptr = column_indices.data_ptr<int64_t>();
std::iota(column_indices_ptr, column_indices_ptr + wrapped_dim - 1, 0);
std::iota(column_indices_ptr + wrapped_dim - 1, column_indices_ptr + ndim - 2, wrapped_dim);
auto sizemat_squeezed = at::index_select(sizemat, 1, column_indices);
auto stridemat_squeezed = at::index_select(stridemat, 1, column_indices);
return create_nested_view_tensor(
self, sizemat_squeezed, stridemat_squeezed, std::vector<int64_t>(self_ptr->get_storage_offsets()));
}
Tensor unsqueeze_nested(const Tensor& self, int64_t dim) {
auto self_ptr = get_nested_tensor_impl(self);
int64_t ndim = self_ptr->dim();
int64_t wrapped_dim = at::maybe_wrap_dim(dim, ndim + 1);
TORCH_CHECK(wrapped_dim > 0,
"unsqueeze(): For nested tensors, unsqueezing dimension 0 is not supported at the moment ",
"if you need this feature, please open an issue on github describing your use case.");
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
const Tensor& stridemat = self_ptr->get_nested_stride_tensor();
auto mat_dim = wrapped_dim - 1;
Tensor new_size = sizemat.new_ones({sizemat.size(0), 1});
Tensor sizemat_unsqueezed = at::cat({sizemat.slice(1, 0, mat_dim),
new_size,
sizemat.slice(1, mat_dim, ndim)}, 1);
Tensor new_stride;
if (wrapped_dim == ndim) {
new_stride = stridemat.new_ones({stridemat.size(0), 1});
} else {
new_stride = (stridemat.select(1, mat_dim - 1) * sizemat.select(1, mat_dim - 1)).unsqueeze(-1);
}
Tensor stridemat_unsqueezed = at::cat({stridemat.slice(1, 0, mat_dim),
new_stride,
stridemat.slice(1, mat_dim, ndim)}, 1);
return create_nested_view_tensor(
self, sizemat_unsqueezed, stridemat_unsqueezed, std::vector<int64_t>(self_ptr->get_storage_offsets()));
}
// utilities supporting `view_nested` and `reshape_nested`
namespace {
// Args:
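
Both kernels manipulate the nested size and stride matrices, which have one row per component tensor and one column per non-batch dimension: unsqueeze inserts a column, squeeze(dim) drops one via index_select. A small Python sketch of just the size-matrix bookkeeping (the helper names unsqueeze_sizemat/squeeze_sizemat are illustrative, not part of the PR; stride handling is elided):

import torch

def unsqueeze_sizemat(sizemat, wrapped_dim):
    # insert a column of ones at position wrapped_dim - 1
    # (column 0 of the matrix corresponds to tensor dim 1, just after the batch dim)
    mat_dim = wrapped_dim - 1
    ones_col = sizemat.new_ones((sizemat.size(0), 1))
    return torch.cat([sizemat[:, :mat_dim], ones_col, sizemat[:, mat_dim:]], dim=1)

def squeeze_sizemat(sizemat, wrapped_dim):
    # drop the column of the squeezed dim; only valid when that column is all ones
    keep = [c for c in range(sizemat.size(1)) if c != wrapped_dim - 1]
    return sizemat[:, keep]

# components of shape (2, 3) and (5, 3) -> sizemat [[2, 3], [5, 3]]
sizemat = torch.tensor([[2, 3], [5, 3]])
assert unsqueeze_sizemat(sizemat, 3).tolist() == [[2, 3, 1], [5, 3, 1]]
assert squeeze_sizemat(unsqueeze_sizemat(sizemat, 3), 3).tolist() == [[2, 3], [5, 3]]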

View File

@@ -606,7 +606,6 @@ class TestNestedTensorDeviceType(TestCase):
self.assertRaises(IndexError, lambda: nt[2])
self.assertRaises(IndexError, lambda: nt[-3])
self.assertRaises(NotImplementedError, lambda: nt[:])
self.assertRaises(NotImplementedError, lambda: nt[None])
self.assertRaises(NotImplementedError, lambda: nt[...])
# tuple of indices: only support integer in the batch dimension
# + all possible indexing in the original tensor dimensions
@@ -1295,6 +1294,48 @@ class TestNestedTensorDeviceType(TestCase):
ptT = pt.transpose(-1, -2)
self.assertEqual(ptT, ptT_from_ntT)
@dtypes(torch.float, torch.float16, torch.double)
def test_squeeze_unsqueeze(self, device, dtype):
a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(5, 3)
nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype)
# error case: squeeze no dimension
self.assertRaisesRegex(
RuntimeError,
"For nested tensors, squeeze without the dim argument",
lambda: nt.squeeze()
)
# error case: squeeze nested dimension
self.assertRaisesRegex(
RuntimeError,
"For nested tensors, squeezing dimension 0",
lambda: nt.squeeze(0)
)
# error case: dimension out of range
self.assertRaises(IndexError, lambda: nt.squeeze(3))
# error case: squeeze nested tensor of singleton tensors
c = torch.ones(1)
nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype)
self.assertRaisesRegex(
RuntimeError,
"For nested tensors, squeezing a nested tensor of singleton",
lambda: nt_singleton.squeeze(1)
)
# squeezing a dim which does not have size 1 should be a no-op
nt2 = nt.squeeze(-1)
self.assertEqual(nt, nt2)
# test cases that should work
for i in range(-2, 3):
if (i == 0):
continue
nt_unsqueezed = nt.unsqueeze(i)
size_idx = i if i < 0 else i - 1
self.assertEqual(nt_unsqueezed._nested_tensor_size()[:, size_idx], torch.ones(2, dtype=torch.long))
nt_squeezed = nt_unsqueezed.squeeze(i)
self.assertEqual(nt_squeezed, nt)
@dtypes(torch.float, torch.float16, torch.double)
def test_transpose_inference_mode_interaction(self, device, dtype):
nt = self.random_nt(device, dtype, 4, (4, 4))
@@ -1767,6 +1808,52 @@ class TestNestedTensorAutograd(TestCase):
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
def test_nested_tensor_squeeze_backward(self):
nt = torch.nested.nested_tensor([torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], requires_grad=True)
with torch.no_grad():
pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
ynt = nt.squeeze(-1)
ypt = pt.squeeze(-1)
ynt.backward(ynt.clone())
ypt.backward(ypt.clone())
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
def test_nested_tensor_squeeze_gradcheck(self):
a = torch.randn((2, 6, 1), dtype=torch.float64, requires_grad=True)
b = torch.randn((3, 6, 1), dtype=torch.float64, requires_grad=True)
def grad_test_func(a, b):
nt = torch.nested.as_nested_tensor([a, b])
result = nt.squeeze(-1)
return torch.nested.to_padded_tensor(result, 0.0)
assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
def test_nested_tensor_unsqueeze_backward(self):
nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True)
with torch.no_grad():
pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
ynt = nt.unsqueeze(2)
ypt = pt.unsqueeze(2)
ynt.backward(ynt.clone())
ypt.backward(ypt.clone())
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
def test_nested_tensor_unsqueeze_gradcheck(self):
a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True)
b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True)
def grad_test_func(a, b):
nt = torch.nested.as_nested_tensor([a, b])
result = nt.unsqueeze(-1)
return torch.nested.to_padded_tensor(result, 0.0)
assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
def test_nested_tensor_linear(self):
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64)

View File

@@ -1497,8 +1497,12 @@
result: auto_linear
- name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
self: unsqueeze_to(grad, dim, self.sym_sizes())
result: auto_linear
dispatch:
Default:
self: unsqueeze_to(grad, dim, self.sym_sizes())
result: auto_linear
AutogradNestedTensor:
self: grad.unsqueeze(dim)
- name: squeeze_(Tensor(a!) self) -> Tensor(a!)
self: unsqueeze_to(grad, self.sym_sizes())
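
The new AutogradNestedTensor entry means that, for nested tensors, the backward of squeeze(dim) simply unsqueezes the incoming gradient back at dim, while the dense path keeps using unsqueeze_to. A minimal sketch of that behavior, following the same backward pattern as the tests above (shapes illustrative):

import torch

a = torch.randn(2, 6, 1, dtype=torch.float64, requires_grad=True)
b = torch.randn(3, 6, 1, dtype=torch.float64, requires_grad=True)
nt = torch.nested.as_nested_tensor([a, b])
out = nt.squeeze(-1)       # components now have shapes (2, 6) and (3, 6)
out.backward(out.clone())  # same pattern as test_nested_tensor_squeeze_backward
# each leaf gradient is its component of `out` unsqueezed back to a trailing size-1 dim,
# so here a.grad reproduces a's values exactly
assert torch.allclose(a.grad, a.detach())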