Add nested squeeze.dim and unsqueeze (#86813)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86813 Approved by: https://github.com/drisspg
parent e531cf7b2e · commit ab69550678
@@ -4893,6 +4893,7 @@
  dispatch:
    CompositeExplicitAutograd: squeeze
    QuantizedCPU, QuantizedCUDA: squeeze_quantized
    NestedTensorCPU, NestedTensorCUDA: squeeze_nested

- func: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
  variants: function, method

@@ -4901,6 +4902,7 @@
  dispatch:
    CompositeExplicitAutograd: squeeze
    QuantizedCPU, QuantizedCUDA: squeeze_quantized
    NestedTensorCPU, NestedTensorCUDA: squeeze_dim_nested
  tags: canonical

- func: squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a)

@@ -5531,6 +5533,7 @@
    CompositeExplicitAutograd: unsqueeze
    SparseCPU, SparseCUDA: unsqueeze_sparse
    QuantizedCPU, QuantizedCUDA: unsqueeze_quantized
    NestedTensorCPU, NestedTensorCUDA: unsqueeze_nested
  tags: canonical

- func: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)
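The three hunks above (native_functions.yaml) register nested-tensor kernels for squeeze, squeeze.dim, and unsqueeze. A minimal Python sketch of the behaviour this enables, assuming a PyTorch build that includes this change:

import torch

# Nested tensor whose components share a trailing singleton dimension.
nt = torch.nested.nested_tensor([torch.randn(2, 6, 1), torch.randn(3, 6, 1)])

squeezed = nt.squeeze(-1)           # dispatches to squeeze_dim_nested
restored = squeezed.unsqueeze(-1)   # dispatches to unsqueeze_nested

print(nt.dim(), squeezed.dim(), restored.dim())  # 4 3 4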
@@ -1041,6 +1041,73 @@ Tensor transpose_nested(const Tensor& self, int64_t dim0, int64_t dim1) {
      self, sizemat_transposed, stridemat_transposed, std::vector<int64_t>(self_ptr->get_storage_offsets()));
}

Tensor squeeze_nested(const Tensor& self) {
  TORCH_CHECK(false,
  "squeeze(): For nested tensors, squeeze without the dim argument is not supported ",
  "at the moment, however you can use squeeze(Tensor self, int dim) instead ",
  "if you need this feature, please open an issue on github describing your use case.");
  return self;
}

Tensor squeeze_dim_nested(const Tensor& self, int64_t dim) {
  auto self_ptr = get_nested_tensor_impl(self);
  int64_t ndim = self_ptr->dim();
  int64_t wrapped_dim = at::maybe_wrap_dim(dim, ndim);
  TORCH_CHECK(wrapped_dim > 0,
  "squeeze(): For nested tensors, squeezing dimension 0 is not supported at the moment ",
  "if you need this feature, please open an issue on github describing your use case.");
  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
  const Tensor& stridemat = self_ptr->get_nested_stride_tensor();
  // if tensor.size(dim) != 1, torch.squeeze returns the tensor unchanged; we do the same here
  c10::optional<int64_t> size_dim = self_ptr->opt_size(dim);
  if (!(size_dim.has_value() && size_dim.value() == 1)) {
    // detach to avoid triggering throw_error_if_base_and_tensor_are_same
    return self.detach();
  }
  // if ndim == 2 and we pass the above if statement we should have a
  // nested tensor of singleton tensors
  TORCH_CHECK(ndim != 2,
  "squeeze(): For nested tensors, squeezing a nested tensor of singleton tensors is not ",
  "supported at the moment, if you need this feature, please open an issue on github ",
  "describing your use case.");
  auto column_indices = sizemat.new_empty(ndim - 2);
  int64_t* column_indices_ptr = column_indices.data_ptr<int64_t>();
  std::iota(column_indices_ptr, column_indices_ptr + wrapped_dim - 1, 0);
  std::iota(column_indices_ptr + wrapped_dim - 1, column_indices_ptr + ndim - 2, wrapped_dim);
  auto sizemat_squeezed = at::index_select(sizemat, 1, column_indices);
  auto stridemat_squeezed = at::index_select(stridemat, 1, column_indices);
  return create_nested_view_tensor(
      self, sizemat_squeezed, stridemat_squeezed, std::vector<int64_t>(self_ptr->get_storage_offsets()));
}

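squeeze_dim_nested works on the (ntensors x (ndim-1)) nested size and stride matrices: squeezing dim d (with d > 0) just drops column d-1 from both. A rough Python sketch of that bookkeeping, with illustrative values rather than the actual ATen internals:

import torch

# Per-component sizes for a nested tensor of logical shape [2, (2|3), 6, 1]:
sizemat = torch.tensor([[2, 6, 1],
                        [3, 6, 1]])
wrapped_dim = 3                      # dim being squeezed; its size is 1 everywhere
# Keep every column except the one for wrapped_dim (column wrapped_dim - 1);
# this is the index list the two std::iota calls above construct.
keep = torch.tensor([c for c in range(sizemat.size(1)) if c != wrapped_dim - 1])
sizemat_squeezed = sizemat.index_select(1, keep)
print(sizemat_squeezed)              # tensor([[2, 6], [3, 6]])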
Tensor unsqueeze_nested(const Tensor& self, int64_t dim) {
  auto self_ptr = get_nested_tensor_impl(self);
  int64_t ndim = self_ptr->dim();
  int64_t wrapped_dim = at::maybe_wrap_dim(dim, ndim + 1);
  TORCH_CHECK(wrapped_dim > 0,
  "unsqueeze(): For nested tensors, unsqueezing dimension 0 is not supported at the moment ",
  "if you need this feature, please open an issue on github describing your use case.");
  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
  const Tensor& stridemat = self_ptr->get_nested_stride_tensor();
  auto mat_dim = wrapped_dim - 1;
  Tensor new_size = sizemat.new_ones({sizemat.size(0), 1});
  Tensor sizemat_unsqueezed = at::cat({sizemat.slice(1, 0, mat_dim),
                                       new_size,
                                       sizemat.slice(1, mat_dim, ndim)}, 1);
  Tensor new_stride;
  if (wrapped_dim == ndim) {
    new_stride = stridemat.new_ones({stridemat.size(0), 1});
  } else {
    new_stride = (stridemat.select(1, mat_dim - 1) * sizemat.select(1, mat_dim - 1)).unsqueeze(-1);
  }
  Tensor stridemat_unsqueezed = at::cat({stridemat.slice(1, 0, mat_dim),
                                         new_stride,
                                         stridemat.slice(1, mat_dim, ndim)}, 1);
  return create_nested_view_tensor(
      self, sizemat_unsqueezed, stridemat_unsqueezed, std::vector<int64_t>(self_ptr->get_storage_offsets()));
}

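unsqueeze_nested inserts a column of ones into the nested size matrix and a matching column into the stride matrix. A small sanity sketch (assuming a build with this change): unsqueezing a nested tensor should agree component-wise with unsqueezing each constituent tensor:

import torch

a, b = torch.randn(2, 6), torch.randn(3, 6)
nt = torch.nested.nested_tensor([a, b])
nt_u = nt.unsqueeze(-1)                 # handled by unsqueeze_nested

# Each component picks up a trailing size-1 dimension.
for got, want in zip(nt_u.unbind(), (a.unsqueeze(-1), b.unsqueeze(-1))):
    assert torch.equal(got, want)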
// utilities supporting `view_nested` and `reshape_nested`
namespace {
// Args:
@@ -606,7 +606,6 @@ class TestNestedTensorDeviceType(TestCase):
        self.assertRaises(IndexError, lambda: nt[2])
        self.assertRaises(IndexError, lambda: nt[-3])
        self.assertRaises(NotImplementedError, lambda: nt[:])
        self.assertRaises(NotImplementedError, lambda: nt[None])
        self.assertRaises(NotImplementedError, lambda: nt[...])
        # tuple of indices: only support integer in the batch dimension
        #   + all possible indexing in the original tensor dimensions
@@ -1295,6 +1294,48 @@ class TestNestedTensorDeviceType(TestCase):
        ptT = pt.transpose(-1, -2)
        self.assertEqual(ptT, ptT_from_ntT)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_squeeze_unsqueeze(self, device, dtype):
        a = torch.arange(6).reshape(2, 3)
        b = torch.arange(15).reshape(5, 3)
        nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype)
        # error case: squeeze no dimension
        self.assertRaisesRegex(
            RuntimeError,
            "For nested tensors, squeeze without the dim argument",
            lambda: nt.squeeze()
        )
        # error case: squeeze nested dimension
        self.assertRaisesRegex(
            RuntimeError,
            "For nested tensors, squeezing dimension 0",
            lambda: nt.squeeze(0)
        )
        # error case: dimension out of range
        self.assertRaises(IndexError, lambda: nt.squeeze(3))
        # error case: squeeze nested tensor of singleton tensors
        c = torch.ones(1)
        nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "For nested tensors, squeezing a nested tensor of singleton",
            lambda: nt_singleton.squeeze(1)
        )

        # squeezing a dim which does not have size 1 should be a no-op
        nt2 = nt.squeeze(-1)
        self.assertEqual(nt, nt2)

        # test cases that should work
        for i in range(-2, 3):
            if (i == 0):
                continue
            nt_unsqueezed = nt.unsqueeze(i)
            size_idx = i if i < 0 else i - 1
            self.assertEqual(nt_unsqueezed._nested_tensor_size()[:, size_idx], torch.ones(2, dtype=torch.long))
            nt_squeezed = nt_unsqueezed.squeeze(i)
            self.assertEqual(nt_squeezed, nt)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_transpose_inference_mode_interaction(self, device, dtype):
        nt = self.random_nt(device, dtype, 4, (4, 4))
@@ -1767,6 +1808,52 @@ class TestNestedTensorAutograd(TestCase):

        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)

    def test_nested_tensor_squeeze_backward(self):
        nt = torch.nested.nested_tensor([torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], requires_grad=True)
        with torch.no_grad():
            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)

        ynt = nt.squeeze(-1)
        ypt = pt.squeeze(-1)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)

    def test_nested_tensor_squeeze_gradcheck(self):
        a = torch.randn((2, 6, 1), dtype=torch.float64, requires_grad=True)
        b = torch.randn((3, 6, 1), dtype=torch.float64, requires_grad=True)

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b])
            result = nt.squeeze(-1)
            return torch.nested.to_padded_tensor(result, 0.0)

        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)

    def test_nested_tensor_unsqueeze_backward(self):
        nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True)
        with torch.no_grad():
            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)

        ynt = nt.unsqueeze(2)
        ypt = pt.unsqueeze(2)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)

    def test_nested_tensor_unsqueeze_gradcheck(self):
        a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True)
        b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True)

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b])
            result = nt.unsqueeze(-1)
            return torch.nested.to_padded_tensor(result, 0.0)

        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)

    def test_nested_tensor_linear(self):

        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64)
@@ -1497,8 +1497,12 @@
  result: auto_linear

- name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
  dispatch:
    Default:
      self: unsqueeze_to(grad, dim, self.sym_sizes())
      result: auto_linear
    AutogradNestedTensor:
      self: grad.unsqueeze(dim)

- name: squeeze_(Tensor(a!) self) -> Tensor(a!)
  self: unsqueeze_to(grad, self.sym_sizes())
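The AutogradNestedTensor entry above says that the backward of squeeze(dim) on a nested tensor is simply grad.unsqueeze(dim); nested sizes are per-component, so the dense unsqueeze_to formula is not used. A hedged sketch of that relationship, mirroring the backward tests added in this PR:

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 6, 1), torch.randn(3, 6, 1)], requires_grad=True)
out = nt.squeeze(-1)
grad_out = out.detach().clone()
out.backward(grad_out)

# nt.grad should equal grad_out.unsqueeze(-1), component by component.
for g, want in zip(nt.grad.unbind(), grad_out.unsqueeze(-1).unbind()):
    assert torch.equal(g, want)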