pytorch/torch/export
Huanyu He a4e9a1c90b [TorchRec][PT2 IR][APF] short circuit the flatten/unflatten between EBC and KTRegroupAsDict modules (#136045)
Summary:
# context
* for the root cause and background please refer to this [post](https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/)
* the basic idea of this diff is to **short-circuit the pytree flatten/unflatten function pairs** between two preserved modules, i.e., EBC/fpEBC and KTRegroupAsDict.
NOTE: There can be multiple EBCs and a single KTRegroupAsDict, as shown in the [pic](https://fburl.com/gslide/lcyt8eh3).
* short-circuiting the EBC-KTRegroupAsDict pairs is a special case, and it is required in most cases because of the EBC key-order issue with distributed table lookup.
* all of these operations are hidden behind a control flag, `short_circuit_pytree_ebc_regroup`, on the TorchRec main API call `decapsulate_ir_modules`. The flag should only be visible to the infra layer, not to users.
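To show why a flatten/unflatten pair between two modules can go wrong, here is a simplified, self-contained sketch. It does not use the real pytree or KeyedTensor APIs: `toy_flatten`, `toy_unflatten`, and the feature names are hypothetical, and the assumption that sharding changes the order in which the EBC emits its keys is only one plausible reading of the key-order issue mentioned above (the linked post describes the actual root cause).

```python
from typing import Any, Dict, List, Tuple


def toy_flatten(tree: Dict[str, Any]) -> Tuple[List[Any], List[str]]:
    """Hypothetical stand-in for a pytree flatten: leaves plus a key-order spec."""
    keys = list(tree.keys())
    return [tree[k] for k in keys], keys


def toy_unflatten(leaves: List[Any], spec: List[str]) -> Dict[str, Any]:
    """Hypothetical stand-in for a pytree unflatten: pairs leaves with spec keys by position."""
    return dict(zip(spec, leaves))


# Spec captured at export time, when the EBC emitted features in this order (assumed).
export_spec = ["f1", "f2"]

# Assumption: after sharding, the EBC emits the same features in a different order.
sharded_output = {"f2": "emb_f2", "f1": "emb_f1"}

# Graph with the flatten/unflatten pair: leaves are re-keyed with the stale export spec.
leaves, _runtime_spec = toy_flatten(sharded_output)
roundtrip = toy_unflatten(leaves, export_spec)
print(roundtrip)  # {'f1': 'emb_f2', 'f2': 'emb_f1'}  (values under the wrong keys)

# Short-circuited graph: the consumer receives the producer's output unchanged.
short_circuited = sharded_output
print(short_circuited["f1"])  # emb_f1
```

When the recorded spec matches the runtime order, the pair is a no-op, so removing it loses nothing; when the orders disagree, only the short-circuited path keeps each embedding under its own key.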

# details
* The `_short_circuit_pytree_ebc_regroup` function finds all the EBC/fpEBC and KTRegroupAsDict modules in an unflattened module, retrieves their fqns, and sorts them into in_fqns (regroup_fqns) and out_fqns (ebc_fqns). Because the fpEBC is currently swapped as a whole, some extra fqn logic filters out any EBC that belongs to an enclosing fpEBC.
* a util function, `prune_pytree_flatten_unflatten`, removes the incoming and outgoing pytree flatten/unflatten function calls in the graph module, based on the given fqns.
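The two steps above can be sketched in plain Python. This is not the TorchRec or `torch.export` implementation: the string type tags, the `Node` class, the example fqns, and the helper names `sort_fqns` and `prune_flatten_unflatten` are all hypothetical. The pruning also assumes each KTRegroupAsDict argument comes from a single `unflatten(flatten(ebc_output))` chain, whereas a real graph may gather leaves from several EBCs.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set, Tuple

EBC = "EmbeddingBagCollection"
FP_EBC = "FeatureProcessedEmbeddingBagCollection"
REGROUP = "KTRegroupAsDict"


def sort_fqns(module_types: Dict[str, str]) -> Tuple[List[str], List[str]]:
    """Return (in_fqns, out_fqns), i.e. (regroup fqns, ebc fqns)."""
    fp_ebc_fqns = [fqn for fqn, t in module_types.items() if t == FP_EBC]
    in_fqns = sorted(fqn for fqn, t in module_types.items() if t == REGROUP)
    out_fqns = list(fp_ebc_fqns)
    for fqn, t in module_types.items():
        # An EBC nested under an fpEBC is skipped: the fpEBC is swapped as a whole.
        if t == EBC and not any(fqn.startswith(p + ".") for p in fp_ebc_fqns):
            out_fqns.append(fqn)
    return in_fqns, sorted(out_fqns)


@dataclass
class Node:
    name: str
    op: str  # "placeholder", "call_module", "flatten", "unflatten", or "output"
    target: str = ""  # module fqn for "call_module" nodes
    args: List[str] = field(default_factory=list)


def prune_flatten_unflatten(
    nodes: List[Node], in_fqns: Set[str], out_fqns: Set[str]
) -> List[Node]:
    """Rewire regroup inputs to EBC outputs and drop the now-unused pytree calls."""
    by_name = {n.name: n for n in nodes}
    dead: Set[str] = set()
    for n in nodes:
        if n.op != "call_module" or n.target not in in_fqns:
            continue
        new_args = []
        for a in n.args:
            unflat = by_name[a]
            flat = by_name[unflat.args[0]] if unflat.op == "unflatten" else None
            src = by_name[flat.args[0]] if flat and flat.op == "flatten" else None
            if src and src.op == "call_module" and src.target in out_fqns:
                new_args.append(src.name)  # consume the EBC output directly
                dead.update({unflat.name, flat.name})
            else:
                new_args.append(a)
        n.args = new_args
    # Drop dead nodes once nothing references them; repeat until stable.
    while True:
        users = {a for n in nodes for a in n.args}
        kept = [n for n in nodes if n.name not in dead or n.name in users]
        if len(kept) == len(nodes):
            return kept
        nodes = kept


module_types = {
    "sparse_module.ebc": EBC,
    "sparse_module.fpebc": FP_EBC,
    "sparse_module.fpebc._embedding_bag_collection": EBC,
    "sparse_module.regroup": REGROUP,
}
in_fqns, out_fqns = sort_fqns(module_types)
print(in_fqns, out_fqns)  # ['sparse_module.regroup'] ['sparse_module.ebc', 'sparse_module.fpebc']

graph = [
    Node("x", "placeholder"),
    Node("ebc_out", "call_module", "sparse_module.ebc", ["x"]),
    Node("flat", "flatten", args=["ebc_out"]),
    Node("unflat", "unflatten", args=["flat"]),
    Node("kt", "call_module", "sparse_module.regroup", ["unflat"]),
    Node("out", "output", args=["kt"]),
]
pruned = prune_flatten_unflatten(graph, set(in_fqns), set(out_fqns))
print([n.name for n in pruned])  # ['x', 'ebc_out', 'kt', 'out']
```

Driving the pass by module fqns, rather than by node names, follows the description above, where the sorted in/out fqns decide which pytree calls are removed.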

WARNING: The flag `short_circuit_pytree_ebc_regroup` should be turned on if EBCs are used and EBC sharding is needed. Assertions are also added for the cases where no `KTRegroupAsDict` module can be found or `finalize_interpreter_modules` is not `True`.
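A minimal sketch of the preconditions the warning describes. The helper `check_short_circuit_preconditions` and the string type tags are hypothetical stand-ins; the real assertions live inside TorchRec and may be worded or placed differently.

```python
from typing import Dict, List


def check_short_circuit_preconditions(
    module_types: Dict[str, str],
    short_circuit_pytree_ebc_regroup: bool,
    finalize_interpreter_modules: bool,
) -> List[str]:
    """Hypothetical guard: returns regroup fqns, or [] when the flag is off."""
    if not short_circuit_pytree_ebc_regroup:
        return []  # nothing to check; the pytree calls are left in place
    assert finalize_interpreter_modules, (
        "short circuit requires finalize_interpreter_modules=True"
    )
    regroup_fqns = [fqn for fqn, t in module_types.items() if t == "KTRegroupAsDict"]
    assert regroup_fqns, "short circuit requires a KTRegroupAsDict module"
    return regroup_fqns


print(check_short_circuit_preconditions({"sparse_module.regroup": "KTRegroupAsDict"}, True, True))
# ['sparse_module.regroup']
```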

# additional changes
* absorb the `finalize_interpreter_modules` process inside the torchrec main api `decapsulate_ir_modules`.
* set `graph.owning_module` in export.unflatten as required by the graph modification
* add one more layer of `sparse_module` to more closely mimic the APF model structure.

Test Plan:
# run test
* serializer
```
buck2 run fbcode//mode/opt fbcode//torchrec/ir/tests:test_serializer
```
* apf
```
buck2 run fbcode//mode/opt fbcode//aps_models/ads/gmp/tests/ne/e2e_deterministic_tests:gmp_e2e_ne_tests -- --filter-text 'test_mtml_instagram_model_562438350_single_gpu_with_ir'
```
* local mp run
```
==== Finished E2E deterministic test for mtml_instagram_model_gmp_474023725_non_kjt_unary ====
finished
  test_mtml_instagram_model_562438350_single_gpu_with_ir
Imports took: 6.0s! Profile with --import-profiler.
Executed 1 example in 203.1s:
  Successful: 1
  Failed: 0
  Skipped: 0
  Not executed: 8
https://testslide.readthedocs.io/
```

Differential Revision: D62606738

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136045
Approved by: https://github.com/angelayi
2024-09-17 18:42:56 +00:00
experimental Deprecate _preserve_ops and consolidate with decomp_table (#135080) 2024-09-15 17:01:58 +00:00
passes [export] Make move_to_device_pass function public (#134263) 2024-08-23 23:18:30 +00:00
__init__.py Create export_for_inference API and expose core_aten as public facing API (#135912) 2024-09-15 17:05:07 +00:00
_remove_auto_functionalized_pass.py Redesign custom op functionalization for better re-inplace (#134409) 2024-09-04 17:08:58 +00:00
_remove_effect_tokens_pass.py [export][fx] More robust DCE pass (#132764) 2024-08-06 22:27:22 +00:00
_safeguard.py Flip default value for mypy disallow_untyped_defs [6/11] (#127843) 2024-06-08 18:49:29 +00:00
_trace.py Fix decomp behaviour in export training IR (#134801) 2024-09-05 06:37:44 +00:00
_tree_utils.py
_unlift.py [export] detach constant tensors when they're not registered as buffer or parameter in unlift (#133031) 2024-08-09 20:33:52 +00:00
custom_obj.py
dynamic_shapes.py [export] ignore mark_dynamic() in export (#135536) 2024-09-12 21:22:19 +00:00
exported_program.py Create export_for_inference API and expose core_aten as public facing API (#135912) 2024-09-15 17:05:07 +00:00
graph_signature.py [export] refactor ExportGraphSignature construction (#134059) 2024-08-23 23:29:28 +00:00
unflatten.py [TorchRec][PT2 IR][APF] short circuit the flatten/unflatten between EBC and KTRegroupAsDict modules (#136045) 2024-09-17 18:42:56 +00:00