mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Allow only one -1 in nested view/reshape (#85691)
### This effectively means that we only allow reshaping/viewing of NT with ONE ragged dimension. Behavior before this PR: 1. `-1` allowed for the implicit batch dimension. 2. Multiple `-1`s allowed for pre-existing dimensions. 3. For new dimensions, `-1` is not allowed. It is worth noting that, for the most part, case 3 is basically unreachable: assuming a nested tensor has at least 1 ragged dimension, you would expect at least one `-1` to appear in the proposed shape for the pre-existing dimensions. Behavior after this PR: 1. The batch dimension **must be specified**. 2. **Only one** `-1` is allowed for pre-existing dimensions — **this effectively means that we only allow reshaping/viewing of NT with ONE ragged dimension**. 3. Unchanged. Pull Request resolved: https://github.com/pytorch/pytorch/pull/85691 Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent
1418a663b1
commit
01add6e288
|
|
@ -997,6 +997,8 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_compute_size_stride(
|
|||
& stride = strides[itensor];
|
||||
// compute reshaped size
|
||||
std::vector<int64_t> size_reshaped_vector(proposed_shape.begin() + 1, proposed_shape.end());
|
||||
// only allow one pre-existing dimension to have proposed shape == -1
|
||||
int64_t infer_index_old = -1;
|
||||
// some negative sizes remain to be infered
|
||||
if (ndims_underlying < ndims_underlying_reshaped) {
|
||||
int64_t numel = 1, numel_reshaped = 1;
|
||||
|
|
@ -1005,7 +1007,9 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_compute_size_stride(
|
|||
int64_t& size_reshaped = size_reshaped_vector[idim];
|
||||
TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped);
|
||||
if (size_reshaped == -1) {
|
||||
TORCH_CHECK(infer_index_old == -1, "only one dimension can be inferred");
|
||||
size_reshaped = size[idim];
|
||||
infer_index_old = idim;
|
||||
}
|
||||
numel *= size[idim];
|
||||
numel_reshaped *= size_reshaped;
|
||||
|
|
@ -1087,7 +1091,7 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_compute_size_stride(
|
|||
// Note [Special size rule for nested tensor]
|
||||
// Instead of infering size, -1 means "inherit the old size", so:
|
||||
// * negative size is legal for a ragged dimension
|
||||
// * multiple sizes can be -1
|
||||
// * however, we only allow one -1
|
||||
// In principle we could still infer a dimension,
|
||||
// we are designing a better semantics to include both inheritance and inference
|
||||
Tensor view_nested(const Tensor& self, IntArrayRef proposed_shape) {
|
||||
|
|
@ -1101,19 +1105,10 @@ Tensor view_nested(const Tensor& self, IntArrayRef proposed_shape) {
|
|||
ntensors > 0,
|
||||
"empty nested tensor cannot be reshaped");
|
||||
// basic information after reshaping
|
||||
int64_t ntensors_reshaped;
|
||||
if (proposed_shape[0] >= 0) {
|
||||
ntensors_reshaped = proposed_shape[0];
|
||||
}
|
||||
else if (proposed_shape[0] == -1) {
|
||||
ntensors_reshaped = ntensors;
|
||||
}
|
||||
else {
|
||||
AT_ERROR("invalid shape dimension ", proposed_shape[0]);
|
||||
}
|
||||
int64_t ntensors_reshaped = proposed_shape[0];
|
||||
TORCH_CHECK(
|
||||
ntensors == ntensors_reshaped,
|
||||
"for now view cannot change the implicit batch dimension");
|
||||
"view: For now nested view cannot change or infer the implicit batch dimension");
|
||||
std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
|
||||
strides = NestedTensor_get_strides(self_ptr);
|
||||
// reshaping underlying tensor dimensions does not change offset
|
||||
|
|
@ -1197,19 +1192,10 @@ Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
|
|||
ntensors > 0,
|
||||
"empty nested tensor cannot be reshaped");
|
||||
// basic information after reshaping
|
||||
int64_t ntensors_reshaped{0};
|
||||
if (proposed_shape[0] >= 0) {
|
||||
ntensors_reshaped = proposed_shape[0];
|
||||
}
|
||||
else if (proposed_shape[0] == -1) {
|
||||
ntensors_reshaped = ntensors;
|
||||
}
|
||||
else {
|
||||
AT_ERROR("invalid shape dimension ", proposed_shape[0]);
|
||||
}
|
||||
int64_t ntensors_reshaped = proposed_shape[0];
|
||||
TORCH_CHECK(
|
||||
ntensors == ntensors_reshaped,
|
||||
"for now reshape cannot change the implicit batch dimension");
|
||||
"reshape: For now nested reshape cannot change or infer the implicit batch dimension");
|
||||
std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
|
||||
strides = NestedTensor_get_strides(self_ptr);
|
||||
// reshaping underlying tensor dimensions does not change offset
|
||||
|
|
|
|||
|
|
@ -1191,11 +1191,11 @@ class TestNestedTensorDeviceType(TestCase):
|
|||
"empty nested tensor cannot be reshaped",
|
||||
lambda: nt_empty.view(-1)
|
||||
)
|
||||
# error case: invalid proposed shape for underlying tensors
|
||||
# error case: -1 for batch size
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"invalid shape dimension -2",
|
||||
lambda: nt.view(-2, 2, 3)
|
||||
r"view: For now nested view cannot change or infer the implicit batch dimension",
|
||||
lambda: nt.view(-1, 2, 3)
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
|
|
@ -1207,9 +1207,10 @@ class TestNestedTensorDeviceType(TestCase):
|
|||
x1 = torch.randn((3, 20), device=device, dtype=dtype)
|
||||
nt = torch.nested.nested_tensor([x0, x1])
|
||||
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
||||
# error case, trying to reshape batch dim to a legit shape
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"for now view cannot change the implicit batch dimension",
|
||||
r"For now nested view cannot change or infer the implicit batch dimension",
|
||||
lambda: nt.transpose(-1, -2).view(40, -1)
|
||||
)
|
||||
# inherit only the ragged dimension
|
||||
|
|
@ -1219,10 +1220,15 @@ class TestNestedTensorDeviceType(TestCase):
|
|||
# (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
|
||||
pt1 = pt.view(2, -1, 5, 4)
|
||||
self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
|
||||
# also inherit regular dimension
|
||||
nt2 = nt1.view(2, -1, -1, 2, 2)
|
||||
pt2 = pt1.view(2, -1, 5, 2, 2)
|
||||
self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2)
|
||||
|
||||
# more than one -1 (even for "old" dims), should fail
|
||||
# this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
|
||||
# but we ban "inherit old behavior" for >1 dimension
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"only one dimension can be inferred",
|
||||
lambda: nt1.view(2, -1, -1, 2, 2)
|
||||
)
|
||||
|
||||
@dtypes(torch.float, torch.float16, torch.double)
|
||||
def test_view_inference_mode_interaction(self, device, dtype):
|
||||
|
|
@ -1259,11 +1265,11 @@ class TestNestedTensorDeviceType(TestCase):
|
|||
"empty nested tensor cannot be reshaped",
|
||||
lambda: nt_empty.reshape(-1)
|
||||
)
|
||||
# error case: invalid proposed shape for underlying tensors
|
||||
# error case: -1 for batch size
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"invalid shape dimension -2",
|
||||
lambda: nt.reshape(-2, 2, 3)
|
||||
r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
|
||||
lambda: nt.reshape(-1, 2, 3)
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
|
|
@ -1275,9 +1281,10 @@ class TestNestedTensorDeviceType(TestCase):
|
|||
x1 = torch.randn((3, 20), device=device, dtype=dtype)
|
||||
nt = torch.nested.nested_tensor([x0, x1]) # (2, (2, 3), 20)
|
||||
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
||||
# error case, trying to reshape batch dim to a legit shape
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"for now reshape cannot change the implicit batch dimension",
|
||||
r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
|
||||
lambda: nt.transpose(-1, -2).reshape(40, -1)
|
||||
)
|
||||
# inherit only the ragged dimension
|
||||
|
|
@ -1287,10 +1294,15 @@ class TestNestedTensorDeviceType(TestCase):
|
|||
# (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
|
||||
pt1 = pt.reshape(2, -1, 5, 4)
|
||||
self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
|
||||
# also inherit regular dimension
|
||||
nt2 = nt1.reshape(2, -1, -1, 2, 2)
|
||||
pt2 = pt1.reshape(2, -1, 5, 2, 2)
|
||||
self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2)
|
||||
|
||||
# more than one -1 (even for "old" dims), should fail
|
||||
# this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
|
||||
# but we ban "inherit old behavior" for >1 dimension
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"only one dimension can be inferred",
|
||||
lambda: nt1.reshape(2, -1, -1, 2, 2)
|
||||
)
|
||||
|
||||
@parametrize("input_dim", [3, 4])
|
||||
def test_scaled_dot_product_attention(self, device, input_dim):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user