From 543d50db2bff4c1e19936bf652b6cbae9c0b8c7d Mon Sep 17 00:00:00 2001 From: Chen Date: Sat, 13 Sep 2025 03:24:26 +0000 Subject: [PATCH] Fix torch export with dict input nested in args (#162618) Investigated together with @pyemma and @taotaohuang001 ## Problem when calling the exported module with a dict nested in the args tuple, it will raise the following complaint ``` Traceback (most recent call last): File "/home/chzhu/infinitrain/test_torch_export.py", line 32, in print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)})) File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 424, in __call__ raise e File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 411, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl return inner() File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1806, in inner args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn return fn(*args, **kwargs) File 
"/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 81, in _check_input_constraints_pre_hook flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec) File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 64, in _check_inputs_match raise ValueError( # noqa: B904 ValueError: Trying to flatten user inputs with exported input tree spec: TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a1', 'a2'], [*, *])]), TreeSpec(dict, [], [])]) but actually got inputs with tree spec of: TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a2', 'a1'], [*, *])]), TreeSpec(dict, [], [])]). Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing. ``` ## How to reproduce the issue ```python import torch # create a nn.Module with data_batch as input and output as output class MyModel(torch.nn.Module): def __init__(self): super(MyModel, self).__init__() self.linear = torch.nn.Linear(10, 1) def forward(self, data_batch): h1 = self.linear(data_batch["a1"]) h2 = self.linear(data_batch["a2"]) return h1 + h2 # torch export this module model = MyModel() example_args_forward = ( { "a1": torch.randn(10), "a2": torch.randn(10), }, ) exported_model = torch.export.export(model, example_args_forward, strict=True) # save the exported model torch.export.save(exported_model, "exported_model.pt2") # load the exported model exported_model = torch.export.load("exported_model.pt2").module() # run the exported model print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)})) ``` ## Root Cause Input spec is encoded as [TreeSpec](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/utils/_pytree.py#L1059) in torch export. With (args, kwargs) at the top level. 
When we call the exported model, it has a pre-execution [hook](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L66) to check the input TreeSpec matches the received TreeSpec, where in TreeSpec, the dict key order is preserved. Something like TreeSpec(dict, ['a2', 'a1'], [*,*]) To work around this, the input check reorders [kwargs](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L67), that is why kwargs can be out of order. But the dict nested in the args is not re-ordered, so any re-ordering of the keys will throw errors. ## Solution Update eq_spec to handle the dict case, where we only guarantee that the key set is the same without ordering constraints. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162618 Approved by: https://github.com/angelayi --- test/export/test_export.py | 31 +++++++++++++++++++++++++++++++ torch/export/_unlift.py | 6 +++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 2c466f162a8..3ec52775cf0 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -16605,6 +16605,37 @@ def forward(self, x): wrapper = Wrapper(pyt_model, example_inputs) wrapper.forward() + def test_export_with_dict_input_nested_in_args(self): + """Test export with dictionary input nested in args.""" + + class MyModel(torch.nn.Module): + def __init__(self): + super(MyModel, self).__init__() + self.linear = torch.nn.Linear(10, 1) + + def forward(self, data_batch): + h1 = self.linear(data_batch["a1"]) + h2 = self.linear(data_batch["a2"]) + return h1 + h2 + + # Create model and example inputs + model = MyModel() + a1 = torch.randn(10) + a2 = torch.randn(10) + original_input = {"a1": a1, "a2": a2} + example_args_forward = (original_input,) + + # Export the model + exported_model = export(model, example_args_forward) + + # Run both models and compare results + 
reordered_input = {"a2": a2, "a1": a1} + original_output = exported_model.module()(reordered_input) + loaded_output = model(original_input) + + # Verify outputs are close (allowing for floating point differences) + torch.testing.assert_close(original_output, loaded_output) + def test_strict_export_with_shared_parameters(self): """Test that parameter names are preserved when there are shared parameters with the same name.""" diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index ae4c09b7c8c..59c5ade5824 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -51,7 +51,11 @@ def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool: return True if _normalize_type(a.type) != _normalize_type(b.type): return False - if a.context != b.context: + if a.type is dict and b.type is dict: + # in the case of dict, the context is list of keys and we allow the keys to be in any order + if set(a.context) != set(b.context): + return False + elif a.context != b.context: return False if len(a.children_specs) != len(b.children_specs): return False