Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Eventually, we should just have one unified way to check for parity between a `DTensor`-sharded model and a replicated model. This PR is a small refactor to work toward that. One current gap in using this `check_sharded_parity` function for 2D is that FSDP's `(Shard(0), Shard(0))` layout differs from that of the `DTensor` APIs, since FSDP shards on dim-0 after TP shards on dim-0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121357
Approved by: https://github.com/weifengpy
ghstack dependencies: #121360

| Name |
|---|
| fsdp |
| fully_shard |
| test_checkpoint.py |
| test_compose.py |
| test_contract.py |
| test_replicate_with_compiler.py |
| test_replicate.py |
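
The layout gap mentioned in the commit message can be illustrated without any distributed setup. The sketch below is not the code from the PR: it uses plain Python lists as stand-ins for tensor rows, assumes a hypothetical 2x2 mesh ordered as `(dp, tp)`, and assumes the tensor has 4 rows that split evenly. The helper names (`chunk`, `shard_dp_then_tp`, `shard_tp_then_dp`) are made up for illustration. It compares sharding dim-0 over the mesh dims left to right (which is how the `(Shard(0), Shard(0))` placement is read here) with sharding over TP first and then over FSDP's data-parallel dim.

```python
def chunk(rows, n):
    """Split a list of row indices into n equal contiguous chunks."""
    size = len(rows) // n
    return [rows[k * size:(k + 1) * size] for k in range(n)]


def shard_dp_then_tp(rows, dp, tp):
    """Shard dim-0 over the dp mesh dim first, then over the tp mesh dim.

    Assumption: this models reading (Shard(0), Shard(0)) left to right
    on a mesh whose dims are ordered (dp, tp).
    """
    return {
        (i, j): chunk(chunk(rows, dp)[i], tp)[j]
        for i in range(dp)
        for j in range(tp)
    }


def shard_tp_then_dp(rows, dp, tp):
    """Shard dim-0 over tp first, then shard each TP shard over dp.

    Assumption: this models FSDP sharding dim-0 after TP has already
    sharded dim-0, as described in the commit message.
    """
    return {
        (i, j): chunk(chunk(rows, tp)[j], dp)[i]
        for i in range(dp)
        for j in range(tp)
    }


rows = list(range(4))  # row indices of a hypothetical 4-row parameter
left_to_right = shard_dp_then_tp(rows, dp=2, tp=2)
tp_then_fsdp = shard_tp_then_dp(rows, dp=2, tp=2)

for coord in sorted(left_to_right):
    print(coord, left_to_right[coord], tp_then_fsdp[coord])
```

In this toy setup the two orders agree on mesh coordinates `(0, 0)` and `(1, 1)` but swap the rows held at `(0, 1)` and `(1, 0)`, so a parity check that reconstructs the full tensor under one convention would compare the wrong rows under the other.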