[ONNX] dynamic_shapes uses DYNAMIC (#153065)

Although Dim.AUTO covers the cases that a user sets more axes to be dynamic than the model actually needs, it silently falls back to STATIC when DYNAMIC fails. This increases the difficulty of debugging.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153065
Approved by: https://github.com/justinchuby
This commit is contained in:
Ti-Tai Wang 2025-05-07 21:48:38 +00:00 committed by PyTorch MergeBot
parent a2891cba2f
commit 773a91c775
2 changed files with 15 additions and 15 deletions

View File

@ -558,7 +558,7 @@ class TestDynamicShapes(common_utils.TestCase):
expected_dynamic_shapes = {
"input_x": [
{
0: torch.export.Dim.AUTO,
0: torch.export.Dim.DYNAMIC,
1: torch.export.Dim.STATIC,
},
{
@ -566,7 +566,7 @@ class TestDynamicShapes(common_utils.TestCase):
1: dimx,
},
],
"input_b": {2: torch.export.Dim.AUTO},
"input_b": {2: torch.export.Dim.DYNAMIC},
}
dynamic_shapes_with_export_dim, need_axis_mapping = (
_dynamic_shapes.convert_str_to_export_dim(dynamic_shapes)
@ -598,7 +598,7 @@ class TestDynamicShapes(common_utils.TestCase):
},
{
0: torch.export.Dim.AUTO,
1: torch.export.Dim.AUTO,
1: torch.export.Dim.DYNAMIC,
},
],
{2: torch.export.Dim.STATIC},

View File

@ -27,17 +27,17 @@ def from_dynamic_axes_to_dynamic_shapes(
input_names: Sequence[str] | None = None,
) -> tuple[dict[str, Any | None] | None, tuple[Any, ...], dict[str, Any] | None]:
"""
Converts dynamic_axes into dynamic_shapes by wrapping the axis names with ``torch.export.Dim.AUTO``.
Converts dynamic_axes into dynamic_shapes by wrapping the axis names with ``torch.export.Dim.DYNAMIC``.
dynamic_axes examples:
(1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
(2) dynamic_axes = {"x": [0], "y": [1]}
these will be converted to dynamic_shapes respectively:
(1) dynamic_shapes = {"x": {0: Dim.AUTO}, "y": {1: Dim.AUTO}}
(2) dynamic_shapes = {"x": {0: Dim.AUTO}, "y": {1: Dim.AUTO}}
(1) dynamic_shapes = {"x": {0: Dim.DYNAMIC}, "y": {1: Dim.DYNAMIC}}
(2) dynamic_shapes = {"x": {0: Dim.DYNAMIC}, "y": {1: Dim.DYNAMIC}}
Detail on Dim.AUTO: `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
Detail on Dim.DYNAMIC: `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
"""
# https://github.com/pytorch/pytorch/pull/128371
# 1. The function does not need to provide dynamic_shapes to torch.export.export
@ -52,7 +52,7 @@ def from_dynamic_axes_to_dynamic_shapes(
dynamic_shapes: dict[str, Any | None] = {}
for input_name, axes in dynamic_axes.items():
# NOTE: torch.export.Dim.AUTO does its best to infer the min and max values
# NOTE: torch.export.Dim.DYNAMIC does its best to infer the min and max values
# from the model, but it's not guaranteed to be dynamic.
if input_name in output_names:
# output names are not needed for dynamic_shapes
@ -63,14 +63,14 @@ def from_dynamic_axes_to_dynamic_shapes(
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
)
dynamic_shapes[input_name] = {
k: torch.export.Dim.AUTO for k, _ in axes.items()
k: torch.export.Dim.DYNAMIC for k, _ in axes.items()
}
elif isinstance(axes, list):
if any(not isinstance(k, int) for k in axes):
raise ValueError(
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
)
dynamic_shapes[input_name] = {k: torch.export.Dim.AUTO for k in axes}
dynamic_shapes[input_name] = {k: torch.export.Dim.DYNAMIC for k in axes}
elif axes is None:
dynamic_shapes[input_name] = None
else:
@ -185,10 +185,10 @@ def convert_str_to_export_dim(
# 1. If there is no string in dynamic_shapes, we do not touch dynamic_shapes
if dynamic_shapes is None or not _any_str_or_dim_in_dynamic_shapes(dynamic_shapes):
return dynamic_shapes, False
# 2. Convert "name" to Dim.AUTO with flattening and identify if there is any string
# to be replaced with Dim.AUTO, and then unflatten it back to the original structure.
# 2. Convert "name" to Dim.DYNAMIC with flattening and identify if there is any string
# to be replaced with Dim.DYNAMIC, and then unflatten it back to the original structure.
# for example: {"y": {0: "dim_0"}, "x": {1: "dim_1"}}
# to {"y": {0: Dim.AUTO}, "x": {1: Dim.AUTO}}
# to {"y": {0: Dim.DYNAMIC}, "x": {1: Dim.DYNAMIC}}
dynamic_shapes_with_export_dim: list[
list[Dim | _DimHint | None] | dict[int, Dim | _DimHint | None] | None
] = []
@ -202,7 +202,7 @@ def convert_str_to_export_dim(
converted_axes_dict: dict[int, Dim | _DimHint | None] = {}
for axis, dim in axes.items():
if isinstance(dim, str):
converted_axes_dict[axis] = torch.export.Dim.AUTO
converted_axes_dict[axis] = torch.export.Dim.DYNAMIC
else:
converted_axes_dict[axis] = dim
dynamic_shapes_with_export_dim.append(converted_axes_dict)
@ -210,7 +210,7 @@ def convert_str_to_export_dim(
converted_axes_list: list[Dim | _DimHint | None] = []
for dim in axes:
if isinstance(dim, str):
converted_axes_list.append(torch.export.Dim.AUTO)
converted_axes_list.append(torch.export.Dim.DYNAMIC)
else:
converted_axes_list.append(dim)
dynamic_shapes_with_export_dim.append(converted_axes_list)