mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix torch export with dict input nested in args (#162618)
Investigated together with @pyemma and @taotaohuang001
## Problem
when calling the exported module with a dict nested in the args tuple, it raises the following complaint:
```
Traceback (most recent call last):
File "/home/chzhu/infinitrain/test_torch_export.py", line 32, in <module>
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 424, in __call__
raise e
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 411, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1806, in inner
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 81, in _check_input_constraints_pre_hook
flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 64, in _check_inputs_match
raise ValueError( # noqa: B904
ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a1', 'a2'], [*,
*])]),
TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a2', 'a1'], [*,
*])]),
TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.
```
## How to reproduce the issue
```python
import torch


# Minimal module whose forward reads two entries, "a1" and "a2", out of a
# single dict argument.
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, data_batch):
        h1 = self.linear(data_batch["a1"])
        h2 = self.linear(data_batch["a2"])
        return h1 + h2


# torch.export this module; note the dict is nested INSIDE the args tuple,
# so its key order is captured in the recorded input TreeSpec.
model = MyModel()
example_args_forward = (
    {
        "a1": torch.randn(10),
        "a2": torch.randn(10),
    },
)
exported_model = torch.export.export(model, example_args_forward, strict=True)

# Save and reload the exported program (round-trips the input TreeSpec).
torch.export.save(exported_model, "exported_model.pt2")
exported_model = torch.export.load("exported_model.pt2").module()

# Call the reloaded module with the dict keys in a DIFFERENT order than traced;
# before the fix this raised the TreeSpec-mismatch ValueError shown above.
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
```
## Root Cause
Input spec is encoded as [TreeSpec](582d278983/torch/utils/_pytree.py (L1059)) in torch export. With (args, kwargs) at the top level. When we call the exported model, it has a pre-execution [hook](582d278983/torch/export/_unlift.py (L66)) to check the input TreeSpec matches the received TreeSpec, where in Treespec, the dict key order is preserved. Something like
TreeSpec(dict, ['a2', 'a1'], [*,*])
To work around this, the input check reorders [kwargs](582d278983/torch/export/_unlift.py (L67)), which is why kwargs can be out of order. But a dict nested in the args is not re-ordered, so any re-ordering of its keys throws an error.
## Solution
Update eq_spec to handle the dict case, where we only guarantee that the key set is the same, without ordering constraints.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162618
Approved by: https://github.com/angelayi
This commit is contained in:
parent
7dd5f7b125
commit
543d50db2b
@@ -16605,6 +16605,37 @@ def forward(self, x):
        wrapper = Wrapper(pyt_model, example_inputs)
        wrapper.forward()
def test_export_with_dict_input_nested_in_args(self):
|
||||
"""Test export with dictionary input nested in args."""
|
||||
|
||||
class MyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
self.linear = torch.nn.Linear(10, 1)
|
||||
|
||||
def forward(self, data_batch):
|
||||
h1 = self.linear(data_batch["a1"])
|
||||
h2 = self.linear(data_batch["a2"])
|
||||
return h1 + h2
|
||||
|
||||
# Create model and example inputs
|
||||
model = MyModel()
|
||||
a1 = torch.randn(10)
|
||||
a2 = torch.randn(10)
|
||||
original_input = {"a1": a1, "a2": a2}
|
||||
example_args_forward = (original_input,)
|
||||
|
||||
# Export the model
|
||||
exported_model = export(model, example_args_forward)
|
||||
|
||||
# Run both models and compare results
|
||||
reordered_input = {"a2": a2, "a1": a1}
|
||||
original_output = exported_model.module()(reordered_input)
|
||||
loaded_output = model(original_input)
|
||||
|
||||
# Verify outputs are close (allowing for floating point differences)
|
||||
torch.testing.assert_close(original_output, loaded_output)

    def test_strict_export_with_shared_parameters(self):
        """Test that parameter names are preserved when there are shared parameters with the same name."""
@@ -51,7 +51,11 @@ def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool:
         return True
     if _normalize_type(a.type) != _normalize_type(b.type):
         return False
-    if a.context != b.context:
+    if a.type is dict and b.type is dict:
+        # in the case of dict, the context is list of keys and we allow the keys to be in any order
+        if set(a.context) != set(b.context):
+            return False
+    elif a.context != b.context:
         return False
     if len(a.children_specs) != len(b.children_specs):
         return False
|
|
|||
Loading…
Reference in New Issue
Block a user