Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79264

The `shard_module` API did not work correctly with a sub-process-group, since `dist.scatter` takes the global rank as input for `src`. Fix this by passing the appropriate global rank to `dist.scatter`.

Differential Revision: [D37062766](https://our.internmc.facebook.com/intern/diff/D37062766/)

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
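The fix hinges on translating a rank that is local to a sub-process-group into the global rank that `dist.scatter` expects for `src`. Below is a minimal, self-contained sketch of that translation, assuming the sub-group is described by the ordered list of global ranks it was built from (the helper name `group_to_global_rank` is hypothetical; in current PyTorch the equivalent lookup is `dist.get_global_rank`):

```python
def group_to_global_rank(group_ranks, group_rank):
    """Map a rank local to a sub-process-group back to its global rank.

    group_ranks: the global ranks that make up the sub-group, in group order
                 (the same list passed to dist.new_group).
    group_rank:  the index of a process within that sub-group.
    """
    return group_ranks[group_rank]

# Example: a sub-group built from global ranks 2 and 3.
sub_pg_ranks = [2, 3]

# Bug pattern: passing the group-local rank (0) straight to dist.scatter's
# `src` argument would address global rank 0, which is not in the sub-group.
# Fixed pattern: translate first, so src becomes 2 (the sub-group's first member).
src_global = group_to_global_rank(sub_pg_ranks, 0)
```

The sketch only illustrates the rank-translation logic; in a real distributed run the translated rank would then be passed as `src=src_global` to `dist.scatter` on the sub-group.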
| Name |
|---|
| checkpoint |
| sharded_optim |
| sharded_tensor |
| sharding_plan |
| sharding_spec |
| test_partial_tensor.py |
| test_replicated_tensor.py |
| test_sharder.py |