Add nested squeeze.dim and unsqueeze (#86813)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86813
Approved by: https://github.com/drisspg

parent e531cf7b2e
commit ab69550678
native_functions.yaml:

@@ -4893,6 +4893,7 @@
   dispatch:
     CompositeExplicitAutograd: squeeze
     QuantizedCPU, QuantizedCUDA: squeeze_quantized
+    NestedTensorCPU, NestedTensorCUDA: squeeze_nested
 
 - func: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
   variants: function, method
@@ -4901,6 +4902,7 @@
   dispatch:
     CompositeExplicitAutograd: squeeze
     QuantizedCPU, QuantizedCUDA: squeeze_quantized
+    NestedTensorCPU, NestedTensorCUDA: squeeze_dim_nested
   tags: canonical
 
 - func: squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a)
@@ -5531,6 +5533,7 @@
     CompositeExplicitAutograd: unsqueeze
     SparseCPU, SparseCUDA: unsqueeze_sparse
     QuantizedCPU, QuantizedCUDA: unsqueeze_quantized
+    NestedTensorCPU, NestedTensorCUDA: unsqueeze_nested
   tags: canonical
 
 - func: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)
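With these dispatch entries in place, Tensor.squeeze(dim) and Tensor.unsqueeze(dim) route to the new nested kernels. A minimal usage sketch, based on the test cases added further down in this PR (the shapes and printed comments are illustrative, not taken from the diff):

import torch

# Two variable-length components; squeeze/unsqueeze act on the non-batch dims.
a = torch.arange(6, dtype=torch.float).reshape(2, 3)
b = torch.arange(15, dtype=torch.float).reshape(5, 3)
nt = torch.nested.nested_tensor([a, b])

nt_u = nt.unsqueeze(-1)    # components become (2, 3, 1) and (5, 3, 1)
nt_s = nt_u.squeeze(-1)    # removing the size-1 dim round-trips back to nt
# nt.squeeze() without a dim and nt.squeeze(0) on the batch dimension still
# raise RuntimeError, as exercised by the error cases in test_squeeze_unsqueeze.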
Nested tensor kernels (C++):

@@ -1041,6 +1041,73 @@ Tensor transpose_nested(const Tensor& self, int64_t dim0, int64_t dim1) {
       self, sizemat_transposed, stridemat_transposed, std::vector<int64_t>(self_ptr->get_storage_offsets()));
 }
 
+Tensor squeeze_nested(const Tensor& self) {
+  TORCH_CHECK(false,
+    "squeeze(): For nested tensors, squeeze without the dim argument is not supported ",
+    "at the moment, however you can use squeeze(Tensor self, int dim) instead ",
+    "if you need this feature, please open an issue on github describing your use case.");
+  return self;
+}
+
+Tensor squeeze_dim_nested(const Tensor& self, int64_t dim) {
+  auto self_ptr = get_nested_tensor_impl(self);
+  int64_t ndim = self_ptr->dim();
+  int64_t wrapped_dim = at::maybe_wrap_dim(dim, ndim);
+  TORCH_CHECK(wrapped_dim > 0,
+    "squeeze(): For nested tensors, squeezing dimension 0 is not supported at the moment ",
+    "if you need this feature, please open an issue on github describing your use case.");
+  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
+  const Tensor& stridemat = self_ptr->get_nested_stride_tensor();
+  // if tensor.size(dim) != 1 torch.squeeze will return the result, we do the same here
+  c10::optional<int64_t> size_dim = self_ptr->opt_size(dim);
+  if (!(size_dim.has_value() && size_dim.value() == 1)) {
+    // detach to avoid triggering throw_error_if_base_and_tensor_are_same
+    return self.detach();
+  }
+  // if ndim == 2 and we pass the above if statement we should have a
+  // nested tensor of singleton tensors
+  TORCH_CHECK(ndim != 2,
+    "squeeze(): For nested tensors, squeezing a nested tensor of singleton tensors is not ",
+    "supported at the moment, if you need this feature, please open an issue on github",
+    "describing your use case.");
+  auto column_indices = sizemat.new_empty(ndim - 2);
+  int64_t* column_indices_ptr = column_indices.data_ptr<int64_t>();
+  std::iota(column_indices_ptr, column_indices_ptr + wrapped_dim - 1, 0);
+  std::iota(column_indices_ptr + wrapped_dim - 1, column_indices_ptr + ndim - 2, wrapped_dim);
+  auto sizemat_squeezed = at::index_select(sizemat, 1, column_indices);
+  auto stridemat_squeezed = at::index_select(stridemat, 1, column_indices);
+  return create_nested_view_tensor(
+    self, sizemat_squeezed, stridemat_squeezed, std::vector<int64_t>(self_ptr->get_storage_offsets()));
+}
+
+Tensor unsqueeze_nested(const Tensor& self, int64_t dim) {
+  auto self_ptr = get_nested_tensor_impl(self);
+  int64_t ndim = self_ptr->dim();
+  int64_t wrapped_dim = at::maybe_wrap_dim(dim, ndim + 1);
+  TORCH_CHECK(wrapped_dim > 0,
+    "unsqueeze(): For nested tensors, unsqueezing dimension 0 is not supported at the moment ",
+    "if you need this feature, please open an issue on github describing your use case.");
+  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
+  const Tensor& stridemat = self_ptr->get_nested_stride_tensor();
+  auto mat_dim = wrapped_dim - 1;
+  Tensor new_size = sizemat.new_ones({sizemat.size(0), 1});
+  Tensor sizemat_unsqueezed = at::cat({sizemat.slice(1, 0, mat_dim),
+                                       new_size,
+                                       sizemat.slice(1, mat_dim, ndim)}, 1);
+  Tensor new_stride;
+  if (wrapped_dim == ndim) {
+    new_stride = stridemat.new_ones({stridemat.size(0), 1});
+  } else {
+    new_stride = (stridemat.select(1, mat_dim - 1) * sizemat.select(1, mat_dim - 1)).unsqueeze(-1);
+  }
+  Tensor stridemat_unsqueezed = at::cat({stridemat.slice(1, 0, mat_dim),
+                                         new_stride,
+                                         stridemat.slice(1, mat_dim, ndim)}, 1);
+  return create_nested_view_tensor(
+    self, sizemat_unsqueezed, stridemat_unsqueezed, std::vector<int64_t>(self_ptr->get_storage_offsets()));
+}
+
 // utilities supporting `view_nested` and `reshape_nested`
 namespace {
 // Args:
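squeeze_dim_nested works on the per-component size/stride matrices: row i holds the sizes (or strides) of component i, and squeezing dimension d of the nested tensor drops column d - 1 (dimension 0, the batch dimension, has no column). A rough Python analogue of that column selection — an illustration of the idea, not the C++ code itself, with made-up component shapes:

import torch

# Size matrix for a nested tensor with components of shape (2, 6, 1) and (3, 6, 1);
# the nested tensor itself is 4-d, so the matrix has ndim - 1 = 3 columns.
sizemat = torch.tensor([[2, 6, 1],
                        [3, 6, 1]])
wrapped_dim = 3                      # nt.squeeze(-1) resolves to dim 3 of the 4-d nested tensor
keep = [c for c in range(sizemat.size(1)) if c != wrapped_dim - 1]
column_indices = torch.tensor(keep)
sizemat_squeezed = torch.index_select(sizemat, 1, column_indices)
print(sizemat_squeezed)              # tensor([[2, 6], [3, 6]])

The same column selection is applied to the stride matrix, and unsqueeze_nested does the inverse: it concatenates a column of ones (plus a matching stride column) at position d - 1.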
test_nestedtensor.py:

@@ -606,7 +606,6 @@ class TestNestedTensorDeviceType(TestCase):
         self.assertRaises(IndexError, lambda: nt[2])
         self.assertRaises(IndexError, lambda: nt[-3])
         self.assertRaises(NotImplementedError, lambda: nt[:])
-        self.assertRaises(NotImplementedError, lambda: nt[None])
         self.assertRaises(NotImplementedError, lambda: nt[...])
         # tuple of indices: only support integer in the batch dimension
         #   + all possible indexing in the original tensor dimensions
@@ -1295,6 +1294,48 @@ class TestNestedTensorDeviceType(TestCase):
         ptT = pt.transpose(-1, -2)
         self.assertEqual(ptT, ptT_from_ntT)
 
+    @dtypes(torch.float, torch.float16, torch.double)
+    def test_squeeze_unsqueeze(self, device, dtype):
+        a = torch.arange(6).reshape(2, 3)
+        b = torch.arange(15).reshape(5, 3)
+        nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype)
+        # error case: squeeze no dimension
+        self.assertRaisesRegex(
+            RuntimeError,
+            "For nested tensors, squeeze without the dim argument",
+            lambda: nt.squeeze()
+        )
+        # error case: squeeze nested dimension
+        self.assertRaisesRegex(
+            RuntimeError,
+            "For nested tensors, squeezing dimension 0",
+            lambda: nt.squeeze(0)
+        )
+        # error case: dimension out of range
+        self.assertRaises(IndexError, lambda: nt.squeeze(3))
+        # error case: squeeze nested tensor of singleton tensors
+        c = torch.ones(1)
+        nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype)
+        self.assertRaisesRegex(
+            RuntimeError,
+            "For nested tensors, squeezing a nested tensor of singleton",
+            lambda: nt_singleton.squeeze(1)
+        )
+
+        # squeezing a dim which does not have size 1 should be a no-op
+        nt2 = nt.squeeze(-1)
+        self.assertEqual(nt, nt2)
+
+        # test cases that should work
+        for i in range(-2, 3):
+            if (i == 0):
+                continue
+            nt_unsqueezed = nt.unsqueeze(i)
+            size_idx = i if i < 0 else i - 1
+            self.assertEqual(nt_unsqueezed._nested_tensor_size()[:, size_idx], torch.ones(2, dtype=torch.long))
+            nt_squeezed = nt_unsqueezed.squeeze(i)
+            self.assertEqual(nt_squeezed, nt)
+
     @dtypes(torch.float, torch.float16, torch.double)
     def test_transpose_inference_mode_interaction(self, device, dtype):
         nt = self.random_nt(device, dtype, 4, (4, 4))
@@ -1767,6 +1808,52 @@ class TestNestedTensorAutograd(TestCase):
 
         self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
 
+    def test_nested_tensor_squeeze_backward(self):
+        nt = torch.nested.nested_tensor([torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], requires_grad=True)
+        with torch.no_grad():
+            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
+
+        ynt = nt.squeeze(-1)
+        ypt = pt.squeeze(-1)
+        ynt.backward(ynt.clone())
+        ypt.backward(ypt.clone())
+
+        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
+
+    def test_nested_tensor_squeeze_gradcheck(self):
+        a = torch.randn((2, 6, 1), dtype=torch.float64, requires_grad=True)
+        b = torch.randn((3, 6, 1), dtype=torch.float64, requires_grad=True)
+
+        def grad_test_func(a, b):
+            nt = torch.nested.as_nested_tensor([a, b])
+            result = nt.squeeze(-1)
+            return torch.nested.to_padded_tensor(result, 0.0)
+
+        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
+
+    def test_nested_tensor_unsqueeze_backward(self):
+        nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True)
+        with torch.no_grad():
+            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
+
+        ynt = nt.unsqueeze(2)
+        ypt = pt.unsqueeze(2)
+        ynt.backward(ynt.clone())
+        ypt.backward(ypt.clone())
+
+        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
+
+    def test_nested_tensor_unsqueeze_gradcheck(self):
+        a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True)
+        b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True)
+
+        def grad_test_func(a, b):
+            nt = torch.nested.as_nested_tensor([a, b])
+            result = nt.unsqueeze(-1)
+            return torch.nested.to_padded_tensor(result, 0.0)
+
+        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
+
     def test_nested_tensor_linear(self):
 
         a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64)
derivatives.yaml:

@@ -1497,8 +1497,12 @@
   result: auto_linear
 
 - name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
-  self: unsqueeze_to(grad, dim, self.sym_sizes())
-  result: auto_linear
+  dispatch:
+    Default:
+      self: unsqueeze_to(grad, dim, self.sym_sizes())
+      result: auto_linear
+    AutogradNestedTensor:
+      self: grad.unsqueeze(dim)
 
 - name: squeeze_(Tensor(a!) self) -> Tensor(a!)
   self: unsqueeze_to(grad, self.sym_sizes())
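The AutogradNestedTensor entry defines the nested backward of squeeze.dim as grad.unsqueeze(dim): the squeezed size-1 dimension is simply reinserted into the incoming gradient. A small dense-tensor sketch of why that formula is correct (illustrative only; the nested case itself is exercised by the gradcheck tests above):

import torch

# squeeze(-1) removes a size-1 dim, so its backward just puts that dim back;
# the Jacobian is the identity up to a reshape.
x = torch.randn(2, 6, 1, requires_grad=True)
y = x.squeeze(-1)              # shape (2, 6)
g = torch.ones_like(y)         # upstream gradient
y.backward(g)
assert torch.equal(x.grad, g.unsqueeze(-1))   # matches grad.unsqueeze(dim)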