Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
In https://github.com/pytorch/pytorch/pull/109977, we observed that `aten.Linear` is not decomposed in inference mode. So, instead of enabling sharding propagation for the linear op, we use `func.decompose` so that it is decomposed into `matmul` and `mm`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110751

Approved by: https://github.com/bdhirsh, https://github.com/wanchaol
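The fallback described in this commit (decompose an op that has no sharding rule into ops that do) can be illustrated with a toy, pure-Python model. This is a hypothetical sketch, not the DTensor implementation: the names `SHARDING_RULES`, `DECOMPOSITIONS`, `dispatch` and `TRACE` are invented for illustration, and the decomposition `linear -> t + matmul + add`, `matmul -> mm` is a simplified assumption. In PyTorch itself the commit relies on `func.decompose`.

```python
# Hypothetical toy model of "decompose instead of adding a sharding rule".
# Op names and tables are illustrative only, not PyTorch's real registries.
from typing import Callable, Dict, List

Matrix = List[List[float]]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def mm(a: Matrix, b: Matrix) -> Matrix:
    # Plain 2-D matrix multiply: each output entry is row(a) . column(b).
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add_bias(a: Matrix, bias: List[float]) -> Matrix:
    # Broadcast the bias vector across every row.
    return [[x + b for x, b in zip(row, bias)] for row in a]


# Ops that, in this toy model, have a sharding rule and run directly.
SHARDING_RULES: Dict[str, Callable] = {"t": transpose, "mm": mm, "add": add_bias}


def linear_decomposition(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    # Assumed simplified decomposition: linear(x, W, b) = matmul(x, W^T) + b.
    out = dispatch("matmul", x, dispatch("t", weight))
    return dispatch("add", out, bias)


# Ops without a sharding rule are rewritten in terms of simpler ops.
DECOMPOSITIONS: Dict[str, Callable] = {
    "linear": linear_decomposition,
    # For 2-D inputs, matmul reduces to mm in this model.
    "matmul": lambda a, b: dispatch("mm", a, b),
}

# Records which ops actually reached a sharding rule.
TRACE: List[str] = []


def dispatch(op: str, *args):
    if op in SHARDING_RULES:
        TRACE.append(op)
        return SHARDING_RULES[op](*args)
    if op in DECOMPOSITIONS:
        # No rule for this op: fall back to its decomposition.
        return DECOMPOSITIONS[op](*args)
    raise NotImplementedError(f"no sharding rule or decomposition for {op!r}")


if __name__ == "__main__":
    out = dispatch("linear", [[1.0, 2.0]], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.5, 0.5, 0.5])
    print(out)    # [[1.5, 2.5, 3.5]]
    print(TRACE)  # ['t', 'mm', 'add']
```

The design point the sketch mirrors is that `linear` never needs its own rule: only the primitive ops it breaks down into (`t`, `mm`, `add` here) reach the sharding layer.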
| File |
|---|
| `__init__.py` |
| `test_ddp_2d_parallel.py` |
| `test_fsdp_2d_parallel.py` |
| `test_parallelize_api.py` |
| `test_tp_examples.py` |
| `test_tp_random_state.py` |
| `test_tp_style.py` |
| `test_view_sharding_dim_change.py` |