Allow only one -1 in nested view/reshape (#85691)

### This effectively means that we only allow reshaping/viewing of a nested tensor with ONE ragged dimension

Behavior before this PR:

1. `-1` allowed for implicit batch dimension
2. multiple `-1`s allowed for pre-existing dimensions
3.  for new dimensions, `-1` is not allowed

It is worth noting that case 3 is mostly unreachable in practice: assuming a nested tensor has at least one ragged dimension, you would expect at least one `-1` among the pre-existing dimensions of the proposed shape. (Case 2 is illustrated in the sketch below.)
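For concreteness, a minimal sketch of the old case-2 behavior, mirroring the test that this PR removes (shapes are illustrative; on a build that includes this PR the last line now raises):

```python
import torch

# nested tensor with one ragged dimension: logical shape (2, (2, 3), 20)
nt = torch.nested.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])
nt1 = nt.view(2, -1, 5, 4)       # (2, (2, 3), 20) -> (2, (2, 3), 5, 4)

# before this PR, both -1s were accepted and simply inherited the old sizes
# (the ragged dimension and the 5); after this PR the same call raises
# "only one dimension can be inferred"
nt2 = nt1.view(2, -1, -1, 2, 2)  # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
```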

Behavior after this PR:
1. batch dimension **must be specified**
2. **only one** `-1` is allowed for pre-existing dimensions; **this effectively means that we only allow reshaping/viewing of a nested tensor with ONE ragged dimension** (see the sketch after this list)
3. unchanged: `-1` is still not allowed for new dimensions
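A minimal sketch of the semantics after this PR, mirroring the updated tests (assumes a PyTorch build that includes this change; same nested tensor of logical shape `(2, (2, 3), 20)` as above):

```python
import torch

nt = torch.nested.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])

# OK: batch dimension given explicitly, a single -1 inherits the ragged dimension
nt1 = nt.view(2, -1, 5, 4)  # -> (2, (2, 3), 5, 4)

# the implicit batch dimension can no longer be -1
try:
    nt.view(-1, 5, 4)
except RuntimeError as e:
    print(e)  # "view: For now nested view cannot change or infer the implicit batch dimension"

# only one -1 is allowed among the pre-existing dimensions
try:
    nt1.view(2, -1, -1, 2, 2)
except RuntimeError as e:
    print(e)  # "only one dimension can be inferred"
```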

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85691
Approved by: https://github.com/cpuhrsch
Mikayla Gawarecki 2022-09-28 20:30:37 +00:00 committed by PyTorch MergeBot
parent 1418a663b1
commit 01add6e288
2 changed files with 37 additions and 39 deletions


@@ -997,6 +997,8 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_compute_size_stride(
& stride = strides[itensor];
// compute reshaped size
std::vector<int64_t> size_reshaped_vector(proposed_shape.begin() + 1, proposed_shape.end());
// only allow one pre-existing dimension to have proposed shape == -1
int64_t infer_index_old = -1;
// some negative sizes remain to be infered
if (ndims_underlying < ndims_underlying_reshaped) {
int64_t numel = 1, numel_reshaped = 1;
@@ -1005,7 +1007,9 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_compute_size_stride(
int64_t& size_reshaped = size_reshaped_vector[idim];
TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped);
if (size_reshaped == -1) {
TORCH_CHECK(infer_index_old == -1, "only one dimension can be inferred");
size_reshaped = size[idim];
infer_index_old = idim;
}
numel *= size[idim];
numel_reshaped *= size_reshaped;
@@ -1087,7 +1091,7 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_compute_size_stride(
// Note [Special size rule for nested tensor]
// Instead of infering size, -1 means "inherit the old size", so:
// * negative size is legal for a ragged dimension
// * multiple sizes can be -1
// * however, we only allow one -1
// In principle we could still infer a dimension,
// we are designing a better semantics to include both inheritance and inference
Tensor view_nested(const Tensor& self, IntArrayRef proposed_shape) {
@@ -1101,19 +1105,10 @@ Tensor view_nested(const Tensor& self, IntArrayRef proposed_shape) {
ntensors > 0,
"empty nested tensor cannot be reshaped");
// basic information after reshaping
int64_t ntensors_reshaped;
if (proposed_shape[0] >= 0) {
ntensors_reshaped = proposed_shape[0];
}
else if (proposed_shape[0] == -1) {
ntensors_reshaped = ntensors;
}
else {
AT_ERROR("invalid shape dimension ", proposed_shape[0]);
}
int64_t ntensors_reshaped = proposed_shape[0];
TORCH_CHECK(
ntensors == ntensors_reshaped,
"for now view cannot change the implicit batch dimension");
"view: For now nested view cannot change or infer the implicit batch dimension");
std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
strides = NestedTensor_get_strides(self_ptr);
// reshaping underlying tensor dimensions does not change offset
@@ -1197,19 +1192,10 @@ Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
ntensors > 0,
"empty nested tensor cannot be reshaped");
// basic information after reshaping
int64_t ntensors_reshaped{0};
if (proposed_shape[0] >= 0) {
ntensors_reshaped = proposed_shape[0];
}
else if (proposed_shape[0] == -1) {
ntensors_reshaped = ntensors;
}
else {
AT_ERROR("invalid shape dimension ", proposed_shape[0]);
}
int64_t ntensors_reshaped = proposed_shape[0];
TORCH_CHECK(
ntensors == ntensors_reshaped,
"for now reshape cannot change the implicit batch dimension");
"reshape: For now nested reshape cannot change or infer the implicit batch dimension");
std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
strides = NestedTensor_get_strides(self_ptr);
// reshaping underlying tensor dimensions does not change offset


@@ -1191,11 +1191,11 @@ class TestNestedTensorDeviceType(TestCase):
"empty nested tensor cannot be reshaped",
lambda: nt_empty.view(-1)
)
# error case: invalid proposed shape for underlying tensors
# error case: -1 for batch size
self.assertRaisesRegex(
RuntimeError,
r"invalid shape dimension -2",
lambda: nt.view(-2, 2, 3)
r"view: For now nested view cannot change or infer the implicit batch dimension",
lambda: nt.view(-1, 2, 3)
)
self.assertRaisesRegex(
RuntimeError,
@@ -1207,9 +1207,10 @@ class TestNestedTensorDeviceType(TestCase):
x1 = torch.randn((3, 20), device=device, dtype=dtype)
nt = torch.nested.nested_tensor([x0, x1])
pt = torch.nested.to_padded_tensor(nt, 0.0)
# error case, trying to reshape batch dim to a legit shape
self.assertRaisesRegex(
RuntimeError,
r"for now view cannot change the implicit batch dimension",
r"For now nested view cannot change or infer the implicit batch dimension",
lambda: nt.transpose(-1, -2).view(40, -1)
)
# inherit only the ragged dimension
@@ -1219,10 +1220,15 @@ class TestNestedTensorDeviceType(TestCase):
# (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
pt1 = pt.view(2, -1, 5, 4)
self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
# also inherit regular dimension
nt2 = nt1.view(2, -1, -1, 2, 2)
pt2 = pt1.view(2, -1, 5, 2, 2)
self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2)
# more than one -1 (even for "old" dims), should fail
# this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
# but we ban "inherit old behavior" for >1 dimension
self.assertRaisesRegex(
RuntimeError,
r"only one dimension can be inferred",
lambda: nt1.view(2, -1, -1, 2, 2)
)
@dtypes(torch.float, torch.float16, torch.double)
def test_view_inference_mode_interaction(self, device, dtype):
@@ -1259,11 +1265,11 @@ class TestNestedTensorDeviceType(TestCase):
"empty nested tensor cannot be reshaped",
lambda: nt_empty.reshape(-1)
)
# error case: invalid proposed shape for underlying tensors
# error case: -1 for batch size
self.assertRaisesRegex(
RuntimeError,
r"invalid shape dimension -2",
lambda: nt.reshape(-2, 2, 3)
r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
lambda: nt.reshape(-1, 2, 3)
)
self.assertRaisesRegex(
RuntimeError,
@@ -1275,9 +1281,10 @@ class TestNestedTensorDeviceType(TestCase):
x1 = torch.randn((3, 20), device=device, dtype=dtype)
nt = torch.nested.nested_tensor([x0, x1]) # (2, (2, 3), 20)
pt = torch.nested.to_padded_tensor(nt, 0.0)
# error case, trying to reshape batch dim to a legit shape
self.assertRaisesRegex(
RuntimeError,
r"for now reshape cannot change the implicit batch dimension",
r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
lambda: nt.transpose(-1, -2).reshape(40, -1)
)
# inherit only the ragged dimension
@@ -1287,10 +1294,15 @@ class TestNestedTensorDeviceType(TestCase):
# (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
pt1 = pt.reshape(2, -1, 5, 4)
self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
# also inherit regular dimension
nt2 = nt1.reshape(2, -1, -1, 2, 2)
pt2 = pt1.reshape(2, -1, 5, 2, 2)
self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2)
# more than one -1 (even for "old" dims), should fail
# this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
# but we ban "inherit old behavior" for >1 dimension
self.assertRaisesRegex(
RuntimeError,
r"only one dimension can be inferred",
lambda: nt1.reshape(2, -1, -1, 2, 2)
)
@parametrize("input_dim", [3, 4])
def test_scaled_dot_product_attention(self, device, input_dim):