mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary:

# context

* For the root cause and background, please refer to this [post](https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/).
* The basic idea of this diff is to **short-circuit the pytree flatten-unflatten function pairs** between two preserved modules, i.e., EBC/fpEBC and KTRegroupAsDict. NOTE: there could be multiple EBCs and a single KTRegroupAsDict, as shown in the [pic](https://fburl.com/gslide/lcyt8eh3) {F1864810545}.
* Short-circuiting the EBC-KTRegroupAsDict pairs is very special and a must in most cases, due to the EBC key-order issue with distributed table lookup.
* All of the operations are hidden behind a control flag `short_circuit_pytree_ebc_regroup` on the torchrec main API call `decapsulate_ir_modules`; the flag should only be visible to the infra layer, not to users.

# details

* The `_short_circuit_pytree_ebc_regroup` function finds all the EBC/fpEBC and KTRegroupAsDict modules in an unflattened module, retrieves their FQNs, and sorts them into in_fqns (regroup_fqns) and out_fqns (ebc_fqns). Because the fpEBC is currently swapped as a whole, some extra FQN logic filters out any EBC that belongs to an enclosing fpEBC.
* A util function `prune_pytree_flatten_unflatten` removes the incoming and outgoing pytree flatten/unflatten function calls in the graph module, based on the given FQNs.

WARNING: the flag `short_circuit_pytree_ebc_regroup` should be turned on if EBCs are used and EBC sharding is needed. Assertions are also added for the cases where no `KTRegroupAsDict` module can be found, or `finalize_interpreter_modules` is not `True`.

# additional changes

* Absorb the `finalize_interpreter_modules` process into the torchrec main API `decapsulate_ir_modules`.
* Set `graph.owning_module` in export.unflatten, as required by the graph modification.
* Add one more layer of `sparse_module` to closely mimic the APF model structure.
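The fpEBC filtering step above can be sketched in plain Python. The helper name and the sample FQNs below are hypothetical, not the actual torchrec implementation; only the filtering rule (drop any EBC whose FQN sits at or under an fpEBC FQN, since the fpEBC is swapped as a whole) comes from the summary.

```python
from typing import List


def filter_nested_ebc_fqns(ebc_fqns: List[str], fpebc_fqns: List[str]) -> List[str]:
    """Drop EBC FQNs that live inside an fpEBC, since the fpEBC is swapped
    as a whole and its inner EBC must not be short-circuited separately.

    Hypothetical helper illustrating the FQN logic described in the diff.
    """
    return [
        fqn
        for fqn in ebc_fqns
        if not any(fqn == p or fqn.startswith(p + ".") for p in fpebc_fqns)
    ]


# Example with made-up module paths:
ebc_fqns = [
    "sparse_module.ebc",        # standalone EBC -> keep
    "sparse_module.fpebc.ebc",  # EBC nested inside an fpEBC -> drop
]
fpebc_fqns = ["sparse_module.fpebc"]
print(filter_nested_ebc_fqns(ebc_fqns, fpebc_fqns))  # ['sparse_module.ebc']
```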
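Conceptually, the short-circuit works because the exported graph flattens inputs before each preserved module and unflattens its outputs afterward; when an EBC feeds directly into KTRegroupAsDict, the unflatten/flatten pair between them is an identity round-trip and can be pruned. The toy graph representation below is a schematic, not real `torch.fx`; node names and the `Node` type are invented for illustration.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Node:
    op: str      # "call_function" or "call_module" (schematic, not torch.fx)
    target: str  # function name or module FQN


def prune_flatten_unflatten(nodes: List[Node]) -> List[Node]:
    """Remove adjacent (tree_unflatten, tree_flatten) pairs, which together
    form an identity on the intermediate values between preserved modules."""
    out: List[Node] = []
    i = 0
    while i < len(nodes):
        if (
            i + 1 < len(nodes)
            and nodes[i].target == "tree_unflatten"
            and nodes[i + 1].target == "tree_flatten"
        ):
            i += 2  # skip the identity round-trip pair
        else:
            out.append(nodes[i])
            i += 1
    return out


graph = [
    Node("call_function", "tree_flatten"),
    Node("call_module", "sparse_module.ebc"),
    Node("call_function", "tree_unflatten"),  # EBC output unflattened ...
    Node("call_function", "tree_flatten"),    # ... then re-flattened for regroup
    Node("call_module", "regroup"),
    Node("call_function", "tree_unflatten"),
]
print([n.target for n in prune_flatten_unflatten(graph)])
# ['tree_flatten', 'sparse_module.ebc', 'regroup', 'tree_unflatten']
```

The real `prune_pytree_flatten_unflatten` operates on the fx graph module and is driven by the collected in/out FQNs rather than simple adjacency, but the effect is the same: the EBC output tensors flow straight into the regroup module.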
Test Plan:

# run test

* serializer
```
buck2 run fbcode//mode/opt fbcode//torchrec/ir/tests:test_serializer
```
* apf
```
buck2 run fbcode//mode/opt fbcode//aps_models/ads/gmp/tests/ne/e2e_deterministic_tests:gmp_e2e_ne_tests -- --filter-text 'test_mtml_instagram_model_562438350_single_gpu_with_ir'
```
* local mp run
```
==== Finished E2E deterministic test for mtml_instagram_model_gmp_474023725_non_kjt_unary ====
finished test_mtml_instagram_model_562438350_single_gpu_with_ir
Imports took: 6.0s! Profile with --import-profiler.
        --_
  |""---__          Executed 1 example in 203.1s:
  |'.|  ||  .  """|   Successful: 1
 | ||  || /|\""-.  |  Failed: 0
 | ||  ||  |  |  |    Skipped: 0
 | ||  ||  | \|/ |    Not executed: 8
 |."|  ||  --""  '__| https://testslide.readthedocs.io/
  --"  |__---"""
```

Differential Revision: D62606738

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136045
Approved by: https://github.com/angelayi