pytorch/torch/csrc/jit/python
tangleintel 7980ed95bd Support unpacking python dictionary in torch.jit.trace() (#81623)
# Support unpacking python dictionary in **torch.jit.trace()**

## Problem statement & Motivation
### Problem 1 (usability):
Suppose you have a model whose forward method is defined as follows:
**`def forward(self, key1=value1, key2=value2, key3=value3)`**
and a dataset in which each data point is a Python dict like this:
**`data = {key1:value1, key3:value3, key2:value2}`**

The problem is that to trace the model with dict data from the given dataset, you need to unpack the dictionary, manually reorder its values, and build a tuple such as **`data_tuple = (value1, value2, value3)`** to pass as the **`example_inputs`** parameter of **`torch.jit.trace()`**. This marshalling step is not user-friendly.
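The manual marshalling step can be sketched in plain Python. This is a minimal sketch, not code from the PR: `Model` is a hypothetical stand-in for a real module, and `dict_to_positional_tuple` is a hypothetical helper that uses `inspect.signature` to reorder the dict's values into forward()'s parameter order.

```python
import inspect


class Model:
    # Hypothetical stand-in for a model; forward() returns its inputs in order.
    def forward(self, key1=None, key2=None, key3=None):
        return (key1, key2, key3)


def dict_to_positional_tuple(fn, data):
    """Hypothetical helper: reorder data's values to follow fn's parameters.

    Every positional parameter of fn must be a key in data; a missing key
    raises KeyError, which is exactly where this approach breaks for Problem 2.
    """
    names = [
        name
        for name, param in inspect.signature(fn).parameters.items()
        if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)
    ]
    return tuple(data[name] for name in names)


model = Model()
data = {"key1": 1, "key3": 3, "key2": 2}  # dataset order differs from forward()
data_tuple = dict_to_positional_tuple(model.forward, data)
print(data_tuple)  # (1, 2, 3)
```

Note that a bound method's signature already excludes `self`, so only the real inputs are reordered.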

### Problem 2 (feasibility):
Suppose you have a model whose forward method is defined as follows:
**`def forward(self, key1=None, key2=None, key3=None)`** -> every default value is **None**
and a dataset in which each data point is a Python dict like this:
**`data = {key1:value1, key3:value3}`** -> only **part of** the inputs accepted by forward is given; the rest should use their default values.

The problem is that tracing the model with dict data from the given dataset is not feasible at all, because you can pass neither a tuple like **`T1 = (value1, value3)`** nor **`T2 = (value1, None, value3)`**. T1 binds value3 to key2 by position, and T2 contains a **None** value, which the tracer's type checking rejects. (Of course, you can pass **`T3 = (value1,)`** to make tracing finish without an exception, but the traced model is probably not what you expect, because different inputs may produce different traced results.)
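The T1 mismatch can be reproduced without the tracer by binding arguments with `inspect.Signature`. A minimal sketch, assuming a hypothetical `Model` stand-in; it only models Python's argument binding and does not reproduce the tracer's rejection of **None** in T2.

```python
import inspect


class Model:
    # Hypothetical stand-in: every input of forward() defaults to None.
    def forward(self, key1=None, key2=None, key3=None):
        return {"key1": key1, "key2": key2, "key3": key3}


model = Model()
sig = inspect.signature(model.forward)
data = {"key1": "value1", "key3": "value3"}  # key2 is absent

# T1 = (value1, value3): positional binding silently assigns value3 to key2.
t1_bound = dict(sig.bind("value1", "value3").arguments)
print(t1_bound)  # {'key1': 'value1', 'key2': 'value3'}

# Binding by name keeps each value on its own key; key2 keeps its default.
kw_bound = sig.bind(**data)
kw_bound.apply_defaults()
print(dict(kw_bound.arguments))  # {'key1': 'value1', 'key2': None, 'key3': 'value3'}
```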

These problems arise with Hugging Face's PyTorch models, especially in text-classification tasks on datasets such as [MRPC](https://paperswithcode.com/dataset/mrpc) and [MNLI](https://paperswithcode.com/dataset/multinli).

## Solution
To address both issues, we propose supporting a new type, a Python dict, as the example_inputs parameter of torch.jit.trace(). Based on the runtime type of the example_inputs object, we either fall back to the original tuple path or take the new dictionary path. Both problem 1 and problem 2 can then be solved by utilizing the **`**`** (dictionary unpacking) operator.
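The type-based dispatch can be sketched as follows. `run_with_example_inputs` is a hypothetical helper, not the actual `torch.jit.trace()` implementation: a dict is unpacked with `**` so values bind by name, a tuple is unpacked with `*` so values bind by position, and treating any other object as a single positional input is an assumption of this sketch.

```python
def run_with_example_inputs(fn, example_inputs):
    """Hypothetical helper mirroring the proposed dispatch (not torch's code)."""
    if isinstance(example_inputs, dict):
        # New dictionary path: values bind to parameters by name.
        return fn(**example_inputs)
    if not isinstance(example_inputs, tuple):
        # Assumption of this sketch: any other object is a single input.
        example_inputs = (example_inputs,)
    # Original tuple path: values bind to parameters by position.
    return fn(*example_inputs)


def forward(key1=None, key2=None, key3=None):
    # Hypothetical stand-in for a model's forward().
    return (key1, key2, key3)


print(run_with_example_inputs(forward, {"key1": 1, "key3": 3, "key2": 2}))  # (1, 2, 3)
print(run_with_example_inputs(forward, {"key1": 1, "key3": 3}))  # (1, None, 3)
print(run_with_example_inputs(forward, (1, 2, 3)))  # (1, 2, 3)
```

The first call covers Problem 1 (no manual reordering), and the second covers Problem 2 (missing keys keep their defaults without passing **None** explicitly).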

## Limitation & Mitigation

1. If we use a dict as example_inputs to trace the model, we must also pass a dictionary to the traced model. (The debug names of the input parameters in the TorchScript IR may be reordered, so we can't assume the traced model's input parameter order matches the original model's.) This needs to be highlighted in the documentation to mitigate the problem.

    For example:
```python
import torch

# Fetch one batch from the dataloader; each batch is a dictionary
# such as {key1: value1, key3: value3, key2: value2}.
# forward() is defined as: def forward(self, key1=value1, key2=value2, key3=value3)
example_inputs_dict = next(iter(dataloader))
jit_model = model.eval()
# Trace the model with the dictionary. The resulting IR is:
# graph(%self : __torch__.module.___torch_mangle_n.Mymodule, %key1 : type1, %key3 : type3, %key2 : type2)
jit_model = torch.jit.trace(jit_model, example_inputs_dict, strict=False)
jit_model = torch.jit.freeze(jit_model)

# OK: call the traced model with the dict unpacked as keyword arguments.
jit_model(**example_inputs_dict)

example_inputs_tuple = (value1, value3, value2)
# Wrong: don't rely on positional argument order when calling the traced model.
jit_model(*example_inputs_tuple)

```
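Why positional calls are unsafe can be illustrated in plain Python. The sketch assumes, as described in item 1, that the traced model's inputs follow the example dict's order; `make_traced_signature` is a hypothetical helper that only models that ordering with `inspect.Signature`, not TorchScript itself.

```python
import inspect


def make_traced_signature(example_inputs_dict):
    """Hypothetical model of the limitation: parameter order follows the
    example dict's insertion order, not forward()'s declared order."""
    params = [
        inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        for name in example_inputs_dict
    ]
    return inspect.Signature(params)


example_inputs_dict = {"key1": "value1", "key3": "value3", "key2": "value2"}
traced_sig = make_traced_signature(example_inputs_dict)
print(traced_sig)  # (key1, key3, key2)

# Keyword call: every value lands on its own name regardless of order.
by_name = dict(traced_sig.bind(**example_inputs_dict).arguments)

# Positional call in forward()'s order (key1, key2, key3): value2 binds to key3.
by_position = dict(traced_sig.bind("value1", "value2", "value3").arguments)
print(by_position)  # {'key1': 'value1', 'key3': 'value2', 'key2': 'value3'}
```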
## Note
1. This PR makes some unit tests introduced in [39601](https://github.com/pytorch/pytorch/pull/39601) fail; I think those cases should be classified, in our solution, as unpacking a tuple that contains a single dictionary element.
2. I think there is an ambiguity: the documentation of **torch.jit.trace()** currently only specifies a tuple or a single Tensor for the example_inputs parameter, yet it seems a dictionary can still be passed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81623
Approved by: https://github.com/davidberard98
2022-10-15 05:33:09 +00:00
init.cpp Reland 2 min/max support for SymInt/Floats, finish as_strided/scatter/squeeze() backward symint support (#86797) 2022-10-13 00:31:19 +00:00
init.h
module_python.h Improve torch::jit::as_{module,object} performance (#84399) 2022-09-07 16:58:28 +00:00
pybind_utils.cpp Decay integer-only (Optional)SymIntArrayRef to IntList in IValue (#86094) 2022-10-03 20:12:32 +00:00
pybind_utils.h Support unpacking python dictionary in torch.jit.trace() (#81623) 2022-10-15 05:33:09 +00:00
pybind.h
python_arg_flatten.cpp [ONNX] Support optional type (#68793) (#73284) 2022-05-04 20:24:30 +00:00
python_arg_flatten.h
python_custom_class.cpp
python_custom_class.h
python_dict.cpp Revert "Revert "Add a lint rule for torch/csrc/util/pybind.h include (#82552)"" (#82599) 2022-08-02 19:37:02 +00:00
python_dict.h
python_interpreter.cpp Revert "Revert "Add a lint rule for torch/csrc/util/pybind.h include (#82552)"" (#82599) 2022-08-02 19:37:02 +00:00
python_ir.cpp [ONNX] Partially re-enable RoiAlign and RoiPool unit tests (#86169) 2022-10-13 14:39:44 +00:00
python_ir.h
python_ivalue.h Revert "Revert "Add a lint rule for torch/csrc/util/pybind.h include (#82552)"" (#82599) 2022-08-02 19:37:02 +00:00
python_list.cpp Revert "Revert "Add a lint rule for torch/csrc/util/pybind.h include (#82552)"" (#82599) 2022-08-02 19:37:02 +00:00
python_list.h Fix sign-compare violations in python_list.h 2022-04-01 19:15:51 +00:00
python_sugared_value.cpp Revert "Revert "Add a lint rule for torch/csrc/util/pybind.h include (#82552)"" (#82599) 2022-08-02 19:37:02 +00:00
python_sugared_value.h [ROCm] Enable/fix unit tests test_stream_args and test_event_args (#82346) 2022-08-01 22:55:15 +00:00
python_tracer.cpp Support unpacking python dictionary in torch.jit.trace() (#81623) 2022-10-15 05:33:09 +00:00
python_tracer.h Support unpacking python dictionary in torch.jit.trace() (#81623) 2022-10-15 05:33:09 +00:00
python_tree_views.cpp Reland "Make debug_pkl smaller by only emitting unique traces." (#73368) 2022-04-18 22:34:21 +00:00
python_tree_views.h
script_init.cpp Support unpacking python dictionary in torch.jit.trace() (#81623) 2022-10-15 05:33:09 +00:00
script_init.h
update_graph_executor_opt.cpp
update_graph_executor_opt.h