mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[export] allow partially specifying keys for dynamic shapes dict spec (#151597)"
This reverts commitc8240e3492. Reverted https://github.com/pytorch/pytorch/pull/151597 on behalf of https://github.com/clee2000 due to broke some export test export/test_converter.py::TestConverter::test_aten_len [GH job link](https://github.com/pytorch/pytorch/actions/runs/14538615968/job/40792673415) [HUD commit link](c8240e3492), bad TD ([comment](https://github.com/pytorch/pytorch/pull/151597#issuecomment-2816127271))
This commit is contained in:
parent
f20a266512
commit
1b267a58a1
|
|
@ -4024,10 +4024,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
re.escape(
|
||||
"`dynamic_shapes` was specified as a dict, but found top-level keys ['k'] "
|
||||
"that weren't present in the arg names of `inputs`: ['x']. "
|
||||
"Since here `inputs` is a list/tuple enclosing a single dict, maybe you "
|
||||
"just forgot to enclose `dynamic_shapes` in a list/tuple?"
|
||||
"When `dynamic_shapes` is specified as a dict, its top-level keys "
|
||||
"must be the arg names ['x'] of `inputs`, but here they are ['k']. "
|
||||
"Since here `inputs` is a list/tuple enclosing a single dict, "
|
||||
"maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?"
|
||||
),
|
||||
):
|
||||
export(M(), inputs, dynamic_shapes=dynamic_shapes)
|
||||
|
|
@ -4080,8 +4080,8 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
re.escape(
|
||||
"`dynamic_shapes` was specified as a dict, but found top-level keys ['k'] "
|
||||
"that weren't present in the arg names of `inputs`: ['x']. "
|
||||
"When `dynamic_shapes` is specified as a dict, its top-level keys "
|
||||
"must be the arg names ['x'] of `inputs`, but here they are ['x', 'k']. "
|
||||
"Alternatively, you could also ignore arg names entirely "
|
||||
"and specify `dynamic_shapes` as a list/tuple matching `inputs`."
|
||||
),
|
||||
|
|
@ -4135,17 +4135,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
):
|
||||
export(O(), inputs, dynamic_shapes=dynamic_shapes)
|
||||
|
||||
class P(torch.nn.Module):
|
||||
def forward(self, x, y, z):
|
||||
return x + 2, y + z
|
||||
|
||||
inputs = (torch.randn(4), torch.randn(6), torch.randn(6))
|
||||
dynamic_shapes = {"x": (Dim.AUTO,)}
|
||||
if not is_retracebility_test(
|
||||
self._testMethodName
|
||||
): # retraceability test forces dict -> tuple conversion
|
||||
export(P(), inputs, dynamic_shapes=dynamic_shapes)
|
||||
|
||||
def test_unbacked_bindings_for_divisible_u_symint(self):
|
||||
from torch._export.utils import _get_shape_env_from_gm
|
||||
from torch.utils._sympy.symbol import prefix_str, symbol_is_type, SymT
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
|
|
@ -1166,7 +1165,6 @@ def _process_export_inputs(mod, args, kwargs, dynamic_shapes):
|
|||
if isinstance(dynamic_shapes, torch.export.ShapesCollection):
|
||||
dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs)
|
||||
|
||||
dynamic_shapes = copy.deepcopy(dynamic_shapes)
|
||||
return args, kwargs, original_in_spec, dynamic_shapes, verify_additional_inputs
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -896,12 +896,11 @@ def _check_dynamic_shapes(
|
|||
if isinstance(dynamic_shapes, dict):
|
||||
got_keys = list(dynamic_shapes.keys())
|
||||
expected_arg_names = list(combined_args.keys())
|
||||
extra_keys = set(got_keys).difference(expected_arg_names)
|
||||
if extra_keys:
|
||||
if sorted(got_keys) != sorted(expected_arg_names):
|
||||
msg = (
|
||||
f"`dynamic_shapes` was specified as a dict, but found top-level keys "
|
||||
f"{list(extra_keys)} that weren't present in the arg names of `inputs`: "
|
||||
f"{expected_arg_names}. "
|
||||
f"When `dynamic_shapes` is specified as a dict, its top-level keys "
|
||||
f"must be the arg names {expected_arg_names} of `inputs`, but "
|
||||
f"here they are {got_keys}. "
|
||||
)
|
||||
if (
|
||||
len(combined_args) == 1
|
||||
|
|
@ -920,9 +919,6 @@ def _check_dynamic_shapes(
|
|||
raise UserError(
|
||||
UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation"
|
||||
)
|
||||
# populate unspecified keys with None
|
||||
for unspec_key in set(combined_args.keys()).difference(dynamic_shapes.keys()):
|
||||
dynamic_shapes[unspec_key] = None
|
||||
|
||||
def check_shape(path, t, dynamic_shape):
|
||||
if isinstance(t, torch.Tensor):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user