[ONNX] Automatically convert dynamic_axes to dynamic_shapes with torch.export.Dim.AUTO (#143158)

With https://github.com/pytorch/pytorch/pull/133620 introducing Dim.AUTO, we can now automatically convert dynamic_axes to dynamic_shapes without specifying min and max values. However, export can still crash when the same spec is shared between inputs, and there is no guarantee that the converted axes will actually be dynamic (see the PR description).
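As a rough illustration of the conversion this PR enables, each user-named axis in dynamic_axes is mapped to torch.export.Dim.AUTO. This is only a sketch with hypothetical input names; the real helper also handles list-form axes, skips output names, and re-nests the result to match the input tree:

```python
import torch

# User-supplied dynamic_axes spec (the string names are only labels here):
dynamic_axes = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
}

# Roughly what is now handed to torch.export.export as dynamic_shapes:
# every named axis becomes torch.export.Dim.AUTO, so no min/max bounds are needed.
dynamic_shapes = {
    name: {axis: torch.export.Dim.AUTO for axis in axes}
    for name, axes in dynamic_axes.items()
}
```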

A follow-up PR should add a post-processing pass on the ONNX side to rename the dynamic shapes (s0, s1, ...) to the user-specified dynamic_axes names.

This PR does the following:
(1) Applies torch.export.Dim.AUTO to dynamic_axes when dynamic_shapes is not provided (see the usage sketch after this list).
(2) Converts args/kwargs to tuple inputs that follow the generated dynamic_shapes format, to avoid errors during torch.export.export.
(3) Avoids a KeyError in the _rename_dynamic_shapes_with_model_inputs function.
(4) Adds a real-world case of a Hugging Face model with kv_cache to the ONNX exporter tests.
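A minimal usage sketch of item (1), assuming a hypothetical toy module (the GPT-J test added in this PR exercises the real path):

```python
import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x + 1


# With dynamo=True and no dynamic_shapes, the exporter now derives
# dynamic_shapes from dynamic_axes using torch.export.Dim.AUTO internally.
onnx_program = torch.onnx.export(
    TinyModel(),
    (torch.randn(2, 16),),
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {0: "batch_size"}},
    dynamo=True,
)
```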
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143158
Approved by: https://github.com/xadupre, https://github.com/shubhambhokare1
titaiwangms 2024-12-18 23:48:58 +00:00 committed by PyTorch MergeBot
parent 15a7a0c37e
commit b23f11c529
3 changed files with 223 additions and 22 deletions


@@ -314,8 +314,46 @@ class TestPyTreeDynamicAxesShapes(common_utils.TestCase):
         )
         self.assertEqual(unflatten_dynamic_shapes, expected_dynamic_shapes)

+    def test__unflatten_dynamic_shapes_with_inputs_tree_succeeds_on_dict_of_mixed_structure(
+        self,
+    ):
+        inputs = {
+            "w": torch.randn(1, 2, 3),
+            "x": ({"x0": torch.randn(1, 2, 3)}, {"x1": torch.randn(1, 2, 3)}),
+            "y": (torch.randn(1, 2, 3), torch.randn(1, 2, 3)),
+            "z": [torch.randn(1, 2, 3), torch.randn(1, 2, 3)],
+        }
+        w_dim_0 = torch.export.Dim("w_dim_0")
+        x0_dim_1 = torch.export.Dim("x0_dim_1")
+        x0_dim_2 = torch.export.Dim("x0_dim_2")
+        x1_dim_1 = torch.export.Dim("x1_dim_1")
+        y0_dim_0 = torch.export.Dim("y0_dim_0")
+        y0_dim_1 = torch.export.Dim("y0_dim_1")
+        y1_dim_2 = torch.export.Dim("y1_dim_2")
+        z0_dim_2 = torch.export.Dim("z0_dim_2")
+        z1_dim_1 = torch.export.Dim("z1_dim_1")
+        dynamic_shapes = {
+            "w": {0: w_dim_0},
+            "x0": {1: x0_dim_1, 2: x0_dim_2},
+            "x1": {1: x1_dim_1},
+            "y0": {0: y0_dim_0, 1: y0_dim_1},
+            "y1": {2: y1_dim_2},
+            "z0": {2: z0_dim_2},
+            "z1": {1: z1_dim_1},
+        }
+        unflatten_dynamic_shapes = _compat._unflatten_dynamic_shapes_with_inputs_tree(
+            inputs, dynamic_shapes
+        )
+        expected_dynamic_shapes = {
+            "w": {0: w_dim_0},
+            "x": ({"x0": {1: x0_dim_1, 2: x0_dim_2}}, {"x1": {1: x1_dim_1}}),
+            "y": ({0: y0_dim_0, 1: y0_dim_1}, {2: y1_dim_2}),
+            "z": [{2: z0_dim_2}, {1: z1_dim_1}],
+        }
+        self.assertEqual(unflatten_dynamic_shapes, expected_dynamic_shapes)
+
     @common_utils.parametrize(
-        "model, args, kwargs,input_names, output_names, dynamic_axes, expected_dynamic_shapes",
+        "model, args, kwargs, input_names, output_names, dynamic_axes, expected_dynamic_shapes",
         [
             # llama-3.2-1B-Instruct (trimmed)
             (
@@ -436,7 +474,7 @@ class TestPyTreeDynamicAxesShapes(common_utils.TestCase):
         dynamic_axes,
         expected_dynamic_shapes,
     ):
-        dynamic_shapes = _compat._from_dynamic_axes_to_dynamic_shapes(
+        dynamic_shapes, _, _ = _compat._from_dynamic_axes_to_dynamic_shapes(
             model,
             args,
             kwargs,


@@ -0,0 +1,155 @@
# Owner(s): ["module: onnx"]
"""Unit LLM tests for the onnx dynamo exporter."""

from __future__ import annotations

from typing import Any

import transformers

import torch
from torch.onnx._internal.exporter import _testing as onnx_testing
from torch.testing._internal import common_utils


class DynamoExporterTest(common_utils.TestCase):
    def test_onnx_export_huggingface_llm_models_with_kv_cache(self):
        model, kwargs, dynamic_axes, input_names, output_names = (
            _prepare_llm_model_gptj_to_test()
        )
        onnx_program = torch.onnx.export(
            model,
            kwargs=kwargs,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            dynamo=True,
        )
        # TODO(titaiwang): Investigate why ORT fails without optimization
        onnx_program.optimize()
        onnx_testing.assert_onnx_program(onnx_program)


def _prepare_llm_model_gptj_to_test() -> (
    tuple[
        torch.nn.Module,
        dict[str, Any],
        dict[str, dict[int, str]],
        list[str],
        list[str],
    ]
):
    model = transformers.GPTJForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-gptj"
    )
    batch_size = 2
    input_seq_len = 16
    mask_seq_len = 32
    active_prob = 0.5
    vocab_size = 1000

    # Generate random input_ids with values between 0 and vocab_size-1
    input_ids = torch.randint(100, vocab_size, (batch_size, input_seq_len))
    # Generate random attention_mask with values 0 or 1, where 1 indicates an active token
    attention_mask = torch.bernoulli(
        torch.full((batch_size, mask_seq_len), active_prob)
    ).int()
    position_ids = torch.tensor(
        [
            [1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1],
            [1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0],
        ]
    )
    past_key_values = [
        (torch.randn(2, 4, 16, 8), torch.randn(2, 4, 16, 8)),
        (torch.randn(2, 4, 16, 8), torch.randn(2, 4, 16, 8)),
        (torch.randn(2, 4, 16, 8), torch.randn(2, 4, 16, 8)),
        (torch.randn(2, 4, 16, 8), torch.randn(2, 4, 16, 8)),
        (torch.randn(2, 4, 16, 8), torch.randn(2, 4, 16, 8)),
    ]
    kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "past_key_values": past_key_values,
    }
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "past_key_values.0.key": {0: "batch_size", 2: "past_sequence_length"},
        "past_key_values.0.value": {0: "batch_size", 2: "past_sequence_length"},
        "past_key_values.1.key": {0: "batch_size", 2: "past_sequence_length"},
        "past_key_values.1.value": {0: "batch_size", 2: "past_sequence_length"},
        "past_key_values.2.key": {0: "batch_size", 2: "past_sequence_length"},
        "past_key_values.2.value": {0: "batch_size", 2: "past_sequence_length"},
        "past_key_values.3.key": {0: "batch_size", 2: "past_sequence_length"},
        "past_key_values.3.value": {0: "batch_size", 2: "past_sequence_length"},
        "past_key_values.4.key": {0: "batch_size", 2: "past_sequence_length"},
        "past_key_values.4.value": {0: "batch_size", 2: "past_sequence_length"},
        "attention_mask": {
            0: "batch_size",
            1: "past_sequence_length + sequence_length",
        },
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"},
        "present.0.key": {0: "batch_size", 2: "past_sequence_length + sequence_length"},
        "present.0.value": {
            0: "batch_size",
            2: "past_sequence_length + sequence_length",
        },
        "present.1.key": {0: "batch_size", 2: "past_sequence_length + sequence_length"},
        "present.1.value": {
            0: "batch_size",
            2: "past_sequence_length + sequence_length",
        },
        "present.2.key": {0: "batch_size", 2: "past_sequence_length + sequence_length"},
        "present.2.value": {
            0: "batch_size",
            2: "past_sequence_length + sequence_length",
        },
        "present.3.key": {0: "batch_size", 2: "past_sequence_length + sequence_length"},
        "present.3.value": {
            0: "batch_size",
            2: "past_sequence_length + sequence_length",
        },
        "present.4.key": {0: "batch_size", 2: "past_sequence_length + sequence_length"},
        "present.4.value": {
            0: "batch_size",
            2: "past_sequence_length + sequence_length",
        },
    }
    input_names = [
        "input_ids",
        "past_key_values.0.key",
        "past_key_values.0.value",
        "past_key_values.1.key",
        "past_key_values.1.value",
        "past_key_values.2.key",
        "past_key_values.2.value",
        "past_key_values.3.key",
        "past_key_values.3.value",
        "past_key_values.4.key",
        "past_key_values.4.value",
        "attention_mask",
        "position_ids",
    ]
    output_names = [
        "logits",
        "present.0.key",
        "present.0.value",
        "present.1.key",
        "present.1.value",
        "present.2.key",
        "present.2.value",
        "present.3.key",
        "present.3.value",
        "present.4.key",
        "present.4.value",
    ]
    return model, kwargs, dynamic_axes, input_names, output_names


if __name__ == "__main__":
    common_utils.run_tests()


@@ -6,7 +6,6 @@ from __future__ import annotations

 import inspect
 import logging
-import re
 import warnings
 from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING
@@ -60,7 +59,9 @@ def _rename_dynamic_shapes_with_model_inputs(
     renamed_dynamic_shapes = {}
     for idx, param_name in enumerate(sig.parameters):
-        renamed_dynamic_shapes[param_name] = dynamic_shapes[input_names[idx]]
+        input_name = input_names[idx]
+        if input_name in dynamic_shapes:
+            renamed_dynamic_shapes[param_name] = dynamic_shapes[input_name]
     return renamed_dynamic_shapes
@@ -73,22 +74,24 @@ def _from_dynamic_axes_to_dynamic_shapes(
     dynamic_axes=None,
     output_names: set[str],
     input_names: Sequence[str] | None = None,
-) -> dict[str, Any | None] | None:
+) -> tuple[dict[str, Any | None] | None, tuple[Any, ...], dict[str, Any] | None]:
     """
+    Converts dynamic_axes into dynamic_shapes by wrapping the axis names with torch.export.Dim.AUTO.
+
     dynamic_axes examples:
     (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
     (2) dynamic_axes = {"x": [0], "y": [1]}

     these will be converted to dynamic_shapes respectively:
-    (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}}
-    (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}}  # auto-generated dim names
+    (1) dynamic_shapes = {"x": {0: Dim.AUTO}, "y": {1: Dim.AUTO}}
+    (2) dynamic_shapes = {"x": {0: Dim.AUTO}, "y": {1: Dim.AUTO}}
+
+    Detail on Dim.AUTO: https://github.com/pytorch/pytorch/pull/133620
     """
     # https://github.com/pytorch/pytorch/pull/128371
     # 1. The function does not need to provide dynamic_shapes to torch.export.export
     if dynamic_axes is None:
-        return None
+        return None, args, kwargs

     if input_names is None:
         input_names = []
@@ -98,25 +101,27 @@ def _from_dynamic_axes_to_dynamic_shapes(
     dynamic_shapes: dict[str, Any | None] = {}
     for input_name, axes in dynamic_axes.items():
-        # NOTE: torch.export.Dim requires strict min and max constraints, and it
-        # dpends on the traced model to provide the correct min and max values.
-        # We set max to 99999 to avoid the constraints violation error with the default int64 max.
-        # https://github.com/pytorch/pytorch/blob/32f585d9346e316e554c8d9bf7548af9f62141fc/test/export/test_export.py#L687
+        # TODO(titaiwang): Add ONNX IR pass to rename default dynamic axes: s0, s1, ...
+        # to the dynamic axes defined by users.
+        # NOTE: torch.export.Dim.AUTO does its best to infer the min and max values
+        # from the model, but it's not guaranteed to be dynamic.
         if input_name in output_names:
             # User specified an output name as a dynamic axis, so we skip it
             continue
         if isinstance(axes, dict):
-            # Dim needs to pass str.isidentifier()
-            # If the max is not set, llm is going to fail, as sequence length is usually bounded within config.
-            # But we also don't want to only support llm. This kind of leaves us with this awkward position.
+            if any(not isinstance(k, int) for k in axes.keys()):
+                raise ValueError(
+                    "The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
+                )
             dynamic_shapes[input_name] = {
-                k: torch.export.Dim(re.sub(r"[^A-Za-z_]", "", v), max=99999)
-                for k, v in axes.items()
+                k: torch.export.Dim.AUTO for k, _ in axes.items()
             }
         elif isinstance(axes, list):
-            dynamic_shapes[input_name] = {
-                k: torch.export.Dim(f"{input_name}_dim_{k}", max=99999) for k in axes
-            }
+            if any(not isinstance(k, int) for k in axes):
+                raise ValueError(
+                    "The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
+                )
+            dynamic_shapes[input_name] = {k: torch.export.Dim.AUTO for k in axes}
         elif axes is None:
             dynamic_shapes[input_name] = None
         else:
@@ -139,7 +144,10 @@ def _from_dynamic_axes_to_dynamic_shapes(
     # We need tree structure to represent dynamic_shapes
    dynamic_shapes = _unflatten_dynamic_shapes_with_inputs_tree(inputs, dynamic_shapes)

-    return dynamic_shapes
+    # Since the dynamic_shapes are now in the order of the model parameters,
+    # we need to convert args and kwargs to the order of the model parameters.
+    return dynamic_shapes, tuple(inputs), {}


 def _unflatten_dynamic_shapes_with_inputs_tree(
@@ -266,7 +274,7 @@ def export_compat(
             UserWarning,
         )
         try:
-            dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
+            dynamic_shapes, args, kwargs = _from_dynamic_axes_to_dynamic_shapes(
                 model,
                 args,
                 kwargs,