[dynamic shapes] aten.constant_pad_nd meta impl (#152129)
We know the output shape, and we know this op always produces a clone, so a meta implementation avoids the data-dependent errors raised by the decomposition. Along with https://github.com/pytorch/pytorch/pull/150483, this should fix https://github.com/pytorch/pytorch/issues/123855.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152129
Approved by: https://github.com/laithsakka
This commit is contained in:
parent 53bf174626
commit 701c0848b8
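In short: the decomposition of aten.constant_pad_nd branches on the (possibly unbacked) pad values while building its result, which forces guards on data-dependent sizes, whereas a meta kernel only has to report the output shape of the clone. A minimal sketch of what this enables, assuming a build that includes this change (not part of the commit itself): F.pad in constant mode works on shape-only "meta" tensors, including with negative padding.

import torch
import torch.nn.functional as F

# "meta" tensors carry only shape/dtype/device, so this call exercises just
# the meta kernel for aten.constant_pad_nd; no data is ever touched.
x = torch.empty(64, 64, 64, device="meta")
y = F.pad(x, (8, -8, 4, 0))   # last dim: +8 - 8, second-to-last dim: +4 + 0
print(y.shape)                # torch.Size([64, 68, 64])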
@@ -4360,6 +4360,26 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         ):
             _ = export(M(), (torch.tensor([2, 3, 5]),))
 
+    @testing.expectedFailureTrainingIRToRunDecomp
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
+    def test_unbacked_pad(self):
+        class Foo(torch.nn.Module):
+            def forward(self, xs, pad):
+                u0, u1, u2 = xs.tolist()
+                x = torch.ones(u0, u1, u2)
+                pl0, pr0, pl1, pr1 = pad.tolist()
+                return torch.nn.functional.pad(x, (pl0, pr0, pl1, pr1))
+
+        x = torch.tensor([64, 64, 64])
+        pad = torch.tensor([8, -8, 4, 0])
+        m = Foo()
+        ep = export(m, (x, pad))
+        self.assertEqual(ep.module()(x, pad).shape, m(x, pad).shape)
+
+        # don't guard on negative/positive pad values
+        pad2 = torch.tensor([-5, 9, 0, 8])
+        self.assertEqual(ep.module()(x, pad2).shape, m(x, pad2).shape)
+
     def test_suggested_fixes_for_data_dependent_errors_basic(self):
         # suggested fixes for data-dependent errors only work in non-strict mode
         strict = False
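The new test reads both the tensor sizes and the pad values out of input tensors, so they are unbacked symbols at trace time. For reference, the output shapes it compares can be worked out by hand from the padding rule (pad pairs apply from the last dimension backwards); a standalone sketch of that arithmetic, not part of the diff:

def padded_shape(shape, pad):
    # pad = (left_last, right_last, left_prev, right_prev, ...)
    out = list(shape)
    for i in range(len(pad) // 2):
        dim = len(shape) - 1 - i
        out[dim] += pad[2 * i] + pad[2 * i + 1]
    return tuple(out)

print(padded_shape((64, 64, 64), (8, -8, 4, 0)))   # (64, 68, 64)
print(padded_shape((64, 64, 64), (-5, 9, 0, 8)))   # (64, 72, 68)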
@@ -27,6 +27,7 @@ from torch._prims_common import (
     IntLike,
     make_contiguous_strides_for,
     Number,
+    suggest_memory_format,
     TensorLike,
 )
 from torch._prims_common.wrappers import (
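The only change in this hunk is importing suggest_memory_format, which the new meta kernel uses so that the meta output reports the same memory format as its input. A quick illustration of what that helper returns (a sketch, not part of the diff):

import torch
from torch._prims_common import suggest_memory_format

t = torch.empty(2, 3, 8, 8).to(memory_format=torch.channels_last)
print(suggest_memory_format(t))                  # torch.channels_last
print(suggest_memory_format(torch.empty(2, 3)))  # torch.contiguous_format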
@@ -7324,6 +7325,48 @@ def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor:
     return res
 
 
+@register_meta(aten.constant_pad_nd)
+@out_wrapper()
+def _constant_pad_nd_meta(input, pad, value=0):
+    # same checks as decomposition in torch/_refs/__init__.py:constant_pad_nd()
+    torch._check(
+        len(pad) % 2 == 0,
+        lambda: f"Length of pad must be even but instead it equals {len(pad)}",
+    )
+
+    input_sizes = input.shape
+    l_inp = len(input_sizes)
+    l_pad = len(pad) // 2
+    l_diff = l_inp - l_pad
+
+    torch._check(
+        l_inp >= l_pad,
+        lambda: "Length of pad should be no more than twice the number of "
+        f"dimensions of the input. Pad length is {len(pad)} while the input has "
+        f"{l_inp} dimensions.",
+    )
+
+    new_shape = list(input_sizes[:l_diff])
+    for i in range(l_pad):
+        pad_idx = len(pad) - ((i + 1) * 2)
+        new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
+        torch._check(
+            new_dim >= 0,
+            lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
+            f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
+            f"which is invalid. Check dimension {l_diff + i} of your input.",
+        )
+        new_shape.append(new_dim)
+
+    return torch.empty(
+        new_shape,
+        dtype=input.dtype,
+        device=input.device,
+        requires_grad=input.requires_grad,
+        memory_format=suggest_memory_format(input),
+    )
+
+
 @register_meta(aten.embedding)
 @out_wrapper()
 def embedding(
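The shape loop walks the pad list backwards: the last pad pair applies to the last dimension, the pair before it to the previous dimension, and the leading l_diff dimensions pass through unchanged, which is why pad_idx counts down as i counts up. A standalone trace of that arithmetic for a 3-d input and a length-4 pad (illustrative only, not part of the diff):

# Mirror of the shape computation in _constant_pad_nd_meta.
input_sizes = (7, 5, 3)
pad = (1, 2, 3, 4)                  # (1, 2) pads the last dim, (3, 4) the middle dim
l_pad = len(pad) // 2               # 2
l_diff = len(input_sizes) - l_pad   # 1 leading dim is copied through

new_shape = list(input_sizes[:l_diff])              # [7]
for i in range(l_pad):
    pad_idx = len(pad) - (i + 1) * 2                # 2, then 0
    new_shape.append(input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1])

print(new_shape)  # [7, 12, 6], matching F.pad(torch.ones(7, 5, 3), pad).shape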