Revert "[export] allow partially specifying keys for dynamic shapes dict spec (#151597)"

This reverts commit c8240e3492.

Reverted https://github.com/pytorch/pytorch/pull/151597 on behalf of https://github.com/clee2000 due to broke some export test export/test_converter.py::TestConverter::test_aten_len [GH job link](https://github.com/pytorch/pytorch/actions/runs/14538615968/job/40792673415) [HUD commit link](c8240e3492), bad TD ([comment](https://github.com/pytorch/pytorch/pull/151597#issuecomment-2816127271))
This commit is contained in:
PyTorch MergeBot 2025-04-18 20:17:44 +00:00
parent f20a266512
commit 1b267a58a1
3 changed files with 10 additions and 27 deletions

View File

@@ -4024,10 +4024,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
re.escape(
"`dynamic_shapes` was specified as a dict, but found top-level keys ['k'] "
"that weren't present in the arg names of `inputs`: ['x']. "
"Since here `inputs` is a list/tuple enclosing a single dict, maybe you "
"just forgot to enclose `dynamic_shapes` in a list/tuple?"
"When `dynamic_shapes` is specified as a dict, its top-level keys "
"must be the arg names ['x'] of `inputs`, but here they are ['k']. "
"Since here `inputs` is a list/tuple enclosing a single dict, "
"maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?"
),
):
export(M(), inputs, dynamic_shapes=dynamic_shapes)
@@ -4080,8 +4080,8 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
re.escape(
"`dynamic_shapes` was specified as a dict, but found top-level keys ['k'] "
"that weren't present in the arg names of `inputs`: ['x']. "
"When `dynamic_shapes` is specified as a dict, its top-level keys "
"must be the arg names ['x'] of `inputs`, but here they are ['x', 'k']. "
"Alternatively, you could also ignore arg names entirely "
"and specify `dynamic_shapes` as a list/tuple matching `inputs`."
),
@@ -4135,17 +4135,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
):
export(O(), inputs, dynamic_shapes=dynamic_shapes)
class P(torch.nn.Module):
def forward(self, x, y, z):
return x + 2, y + z
inputs = (torch.randn(4), torch.randn(6), torch.randn(6))
dynamic_shapes = {"x": (Dim.AUTO,)}
if not is_retracebility_test(
self._testMethodName
): # retraceability test forces dict -> tuple conversion
export(P(), inputs, dynamic_shapes=dynamic_shapes)
def test_unbacked_bindings_for_divisible_u_symint(self):
from torch._export.utils import _get_shape_env_from_gm
from torch.utils._sympy.symbol import prefix_str, symbol_is_type, SymT

View File

@@ -1,6 +1,5 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import copy
import dataclasses
import functools
import inspect
@@ -1166,7 +1165,6 @@ def _process_export_inputs(mod, args, kwargs, dynamic_shapes):
if isinstance(dynamic_shapes, torch.export.ShapesCollection):
dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs)
dynamic_shapes = copy.deepcopy(dynamic_shapes)
return args, kwargs, original_in_spec, dynamic_shapes, verify_additional_inputs

View File

@@ -896,12 +896,11 @@ def _check_dynamic_shapes(
if isinstance(dynamic_shapes, dict):
got_keys = list(dynamic_shapes.keys())
expected_arg_names = list(combined_args.keys())
extra_keys = set(got_keys).difference(expected_arg_names)
if extra_keys:
if sorted(got_keys) != sorted(expected_arg_names):
msg = (
f"`dynamic_shapes` was specified as a dict, but found top-level keys "
f"{list(extra_keys)} that weren't present in the arg names of `inputs`: "
f"{expected_arg_names}. "
f"When `dynamic_shapes` is specified as a dict, its top-level keys "
f"must be the arg names {expected_arg_names} of `inputs`, but "
f"here they are {got_keys}. "
)
if (
len(combined_args) == 1
@@ -920,9 +919,6 @@ def _check_dynamic_shapes(
raise UserError(
UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation"
)
# populate unspecified keys with None
for unspec_key in set(combined_args.keys()).difference(dynamic_shapes.keys()):
dynamic_shapes[unspec_key] = None
def check_shape(path, t, dynamic_shape):
if isinstance(t, torch.Tensor):