[ONNX] dynamic_shapes uses DYNAMIC (#153065)
Although Dim.AUTO covers the case where a user marks more axes as dynamic than the model actually supports, it silently falls back to STATIC whenever a dimension cannot be kept dynamic, which makes debugging harder. Switching the exporter to Dim.DYNAMIC surfaces these failures explicitly at export time.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153065
Approved by: https://github.com/justinchuby
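For context, a minimal sketch (not part of this PR) of the behavioral difference the commit message describes, assuming a hypothetical Model whose nn.Linear pins the last input dimension to 4:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 8)  # pins the last input dim to 4

    def forward(self, x):
        return self.fc(x)

x = torch.randn(2, 4)

# Dim.AUTO: dim 1 cannot be dynamic, so export quietly specializes it to 4
# and succeeds; the user only notices by inspecting the exported program.
ep = torch.export.export(
    Model(), (x,), dynamic_shapes={"x": {1: torch.export.Dim.AUTO}}
)

# Dim.DYNAMIC: the same request is expected to fail loudly at export time,
# pointing at the axis that could not stay dynamic.
# torch.export.export(
#     Model(), (x,), dynamic_shapes={"x": {1: torch.export.Dim.DYNAMIC}}
# )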
This commit is contained in:
parent
a2891cba2f
commit
773a91c775
@@ -558,7 +558,7 @@ class TestDynamicShapes(common_utils.TestCase):
         expected_dynamic_shapes = {
             "input_x": [
                 {
-                    0: torch.export.Dim.AUTO,
+                    0: torch.export.Dim.DYNAMIC,
                     1: torch.export.Dim.STATIC,
                 },
                 {
@@ -566,7 +566,7 @@ class TestDynamicShapes(common_utils.TestCase):
                     1: dimx,
                 },
             ],
-            "input_b": {2: torch.export.Dim.AUTO},
+            "input_b": {2: torch.export.Dim.DYNAMIC},
         }
         dynamic_shapes_with_export_dim, need_axis_mapping = (
             _dynamic_shapes.convert_str_to_export_dim(dynamic_shapes)
@@ -598,7 +598,7 @@ class TestDynamicShapes(common_utils.TestCase):
                 },
                 {
                     0: torch.export.Dim.AUTO,
-                    1: torch.export.Dim.AUTO,
+                    1: torch.export.Dim.DYNAMIC,
                 },
             ],
             {2: torch.export.Dim.STATIC},
@@ -27,17 +27,17 @@ def from_dynamic_axes_to_dynamic_shapes(
     input_names: Sequence[str] | None = None,
 ) -> tuple[dict[str, Any | None] | None, tuple[Any, ...], dict[str, Any] | None]:
     """
-    Converts dynamic_axes into dynamic_shapes by wrapping the axis names with ``torch.export.Dim.AUTO``.
+    Converts dynamic_axes into dynamic_shapes by wrapping the axis names with ``torch.export.Dim.DYNAMIC``.
 
     dynamic_axes examples:
     (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
     (2) dynamic_axes = {"x": [0], "y": [1]}
 
     these will be converted to dynamic_shapes respectively:
-    (1) dynamic_shapes = {"x": {0: Dim.AUTO}, "y": {1: Dim.AUTO}}
-    (2) dynamic_shapes = {"x": {0: Dim.AUTO}, "y": {1: Dim.AUTO}}
+    (1) dynamic_shapes = {"x": {0: Dim.DYNAMIC}, "y": {1: Dim.DYNAMIC}}
+    (2) dynamic_shapes = {"x": {0: Dim.DYNAMIC}, "y": {1: Dim.DYNAMIC}}
 
-    Detail on Dim.AUTO: `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
+    Detail on Dim.DYNAMIC: `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
     """
     # https://github.com/pytorch/pytorch/pull/128371
     # 1. The function does not need to provide dynamic_shapes to torch.export.export
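The mapping the docstring describes is mechanical; as a rough standalone illustration (not the internal helper itself, which also handles output names, kwargs, and validation), a hypothetical axes_to_shapes could look like this:

import torch

def axes_to_shapes(dynamic_axes):
    # Illustration only: every named or listed axis becomes Dim.DYNAMIC,
    # preserving the per-input dict structure torch.export expects.
    dynamic_shapes = {}
    for name, axes in dynamic_axes.items():
        if isinstance(axes, (dict, list)):
            dynamic_shapes[name] = {k: torch.export.Dim.DYNAMIC for k in axes}
        else:
            dynamic_shapes[name] = None
    return dynamic_shapes

# roughly {"x": {0: Dim.DYNAMIC}, "y": {1: Dim.DYNAMIC}} (repr of the hint may vary)
print(axes_to_shapes({"x": {0: "batch"}, "y": [1]}))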
@@ -52,7 +52,7 @@ def from_dynamic_axes_to_dynamic_shapes(
 
     dynamic_shapes: dict[str, Any | None] = {}
     for input_name, axes in dynamic_axes.items():
-        # NOTE: torch.export.Dim.AUTO does its best to infer the min and max values
+        # NOTE: torch.export.Dim.DYNAMIC does its best to infer the min and max values
         # from the model, but it's not guaranteed to be dynamic.
         if input_name in output_names:
             # output names are not needed for dynamic_shapes
@@ -63,14 +63,14 @@ def from_dynamic_axes_to_dynamic_shapes(
                     "The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
                 )
             dynamic_shapes[input_name] = {
-                k: torch.export.Dim.AUTO for k, _ in axes.items()
+                k: torch.export.Dim.DYNAMIC for k, _ in axes.items()
             }
         elif isinstance(axes, list):
             if any(not isinstance(k, int) for k in axes):
                 raise ValueError(
                     "The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
                 )
-            dynamic_shapes[input_name] = {k: torch.export.Dim.AUTO for k in axes}
+            dynamic_shapes[input_name] = {k: torch.export.Dim.DYNAMIC for k in axes}
         elif axes is None:
             dynamic_shapes[input_name] = None
         else:
@@ -185,10 +185,10 @@ def convert_str_to_export_dim(
     # 1. If there is no string in dynamic_shapes, we do not touch dynamic_shapes
     if dynamic_shapes is None or not _any_str_or_dim_in_dynamic_shapes(dynamic_shapes):
         return dynamic_shapes, False
-    # 2. Convert "name" to Dim.AUTO with flattening and identify if there is any string
-    # to be replaced with Dim.AUTO, and then unflatten it back to the original structure.
+    # 2. Convert "name" to Dim.DYNAMIC with flattening and identify if there is any string
+    # to be replaced with Dim.DYNAMIC, and then unflatten it back to the original structure.
     # for example: {"y": {0: "dim_0"}, "x": {1: "dim_1"}}
-    # to {"y": {0: Dim.AUTO}, "x": {1: Dim.AUTO}}
+    # to {"y": {0: Dim.DYNAMIC}, "x": {1: Dim.DYNAMIC}}
     dynamic_shapes_with_export_dim: list[
         list[Dim | _DimHint | None] | dict[int, Dim | _DimHint | None] | None
     ] = []
@@ -202,7 +202,7 @@ def convert_str_to_export_dim(
             converted_axes_dict: dict[int, Dim | _DimHint | None] = {}
             for axis, dim in axes.items():
                 if isinstance(dim, str):
-                    converted_axes_dict[axis] = torch.export.Dim.AUTO
+                    converted_axes_dict[axis] = torch.export.Dim.DYNAMIC
                 else:
                     converted_axes_dict[axis] = dim
             dynamic_shapes_with_export_dim.append(converted_axes_dict)
@@ -210,7 +210,7 @@ def convert_str_to_export_dim(
             converted_axes_list: list[Dim | _DimHint | None] = []
             for dim in axes:
                 if isinstance(dim, str):
-                    converted_axes_list.append(torch.export.Dim.AUTO)
+                    converted_axes_list.append(torch.export.Dim.DYNAMIC)
                 else:
                     converted_axes_list.append(dim)
             dynamic_shapes_with_export_dim.append(converted_axes_list)
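From the user's side, nothing about the accepted inputs changes; string axis names in dynamic_shapes are simply mapped to Dim.DYNAMIC instead of Dim.AUTO. A small usage sketch, assuming a recent PyTorch where torch.onnx.export accepts dynamic_shapes together with dynamo=True (MatMul is a made-up module):

import torch

class MatMul(torch.nn.Module):
    def forward(self, x, y):
        return x @ y

x, y = torch.randn(2, 3), torch.randn(3, 5)

# "batch" is converted internally (now to Dim.DYNAMIC), so an axis that
# could not stay dynamic would fail loudly at export time rather than
# silently producing a static dimension.
onnx_program = torch.onnx.export(
    MatMul(),
    (x, y),
    dynamic_shapes={"x": {0: "batch"}, "y": None},
    dynamo=True,
)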