[dynamic shapes] aten.constant_pad_nd meta impl (#152129)

We know the output shape, and we know this always produces a clone. Avoids data-dependent errors from the decomposition.

along with https://github.com/pytorch/pytorch/pull/150483, should fix https://github.com/pytorch/pytorch/issues/123855
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152129
Approved by: https://github.com/laithsakka
This commit is contained in:
Pian Pawakapan 2025-05-01 08:32:07 +00:00 committed by PyTorch MergeBot
parent 53bf174626
commit 701c0848b8
2 changed files with 63 additions and 0 deletions

View File

@ -4360,6 +4360,26 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
):
_ = export(M(), (torch.tensor([2, 3, 5]),))
@testing.expectedFailureTrainingIRToRunDecomp
@testing.expectedFailureTrainingIRToRunDecompNonStrict
def test_unbacked_pad(self):
class Foo(torch.nn.Module):
def forward(self, xs, pad):
u0, u1, u2 = xs.tolist()
x = torch.ones(u0, u1, u2)
pl0, pr0, pl1, pr1 = pad.tolist()
return torch.nn.functional.pad(x, (pl0, pr0, pl1, pr1))
x = torch.tensor([64, 64, 64])
pad = torch.tensor([8, -8, 4, 0])
m = Foo()
ep = export(m, (x, pad))
self.assertEqual(ep.module()(x, pad).shape, m(x, pad).shape)
# don't guard on negative/positive pad values
pad2 = torch.tensor([-5, 9, 0, 8])
self.assertEqual(ep.module()(x, pad2).shape, m(x, pad2).shape)
def test_suggested_fixes_for_data_dependent_errors_basic(self):
# suggested fixes for data-dependent errors only work in non-strict mode
strict = False

View File

@ -27,6 +27,7 @@ from torch._prims_common import (
IntLike,
make_contiguous_strides_for,
Number,
suggest_memory_format,
TensorLike,
)
from torch._prims_common.wrappers import (
@ -7324,6 +7325,48 @@ def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor:
return res
@register_meta(aten.constant_pad_nd)
@out_wrapper()
def _constant_pad_nd_meta(input, pad, value=0):
# same checks as decomposition in torch/_refs/__init__.py:constant_pad_nd()
torch._check(
len(pad) % 2 == 0,
lambda: f"Length of pad must be even but instead it equals {len(pad)}",
)
input_sizes = input.shape
l_inp = len(input_sizes)
l_pad = len(pad) // 2
l_diff = l_inp - l_pad
torch._check(
l_inp >= l_pad,
lambda: "Length of pad should be no more than twice the number of "
f"dimensions of the input. Pad length is {len(pad)} while the input has "
f"{l_inp} dimensions.",
)
new_shape = list(input_sizes[:l_diff])
for i in range(l_pad):
pad_idx = len(pad) - ((i + 1) * 2)
new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
torch._check(
new_dim >= 0,
lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
f"which is invalid. Check dimension {l_diff + i} of your input.",
)
new_shape.append(new_dim)
return torch.empty(
new_shape,
dtype=input.dtype,
device=input.device,
requires_grad=input.requires_grad,
memory_format=suggest_memory_format(input),
)
@register_meta(aten.embedding)
@out_wrapper()
def embedding(