Commit Graph

31 Commits

Author SHA1 Message Date
Rohit Singh Rathaur
2bcd892c86 [distributed] Replace assert statements in distributed checkpoint with explicit checks (#165256)
Fixes partially #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165256
Approved by: https://github.com/albanD
2025-10-17 20:14:35 +00:00
Maggie Moss
7457d139c5 Add pyrefly suppressions to torch/distributed (7/n) (#165002)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

One more PR after this one.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002
Approved by: https://github.com/oulgen
2025-10-09 04:08:25 +00:00
Yuanyuan Chen
da003d7b95 [3/N] Import Callable from collections.abc in torch/distributed (#164104)
This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
This PR is the follow-up of #164054.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164104
Approved by: https://github.com/Skylion007
2025-09-30 00:28:53 +00:00
Xuehai Pan
f903bc475c [BE] add noqa for flake8 rule B036: found except BaseException without re-raising (#159043)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159043
Approved by: https://github.com/Skylion007
2025-07-25 02:56:34 +00:00
Xuehai Pan
4ccc0381de [BE][5/16] fix typos in torch/ (torch/distributed/) (#156315)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156315
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #156313, #156314
2025-06-23 02:57:28 +00:00
PyTorch MergeBot
145d4cdc11 Revert "[BE][5/16] fix typos in torch/ (torch/distributed/) (#156315)"
This reverts commit c2f0292bd5.

Reverted https://github.com/pytorch/pytorch/pull/156315 on behalf of https://github.com/atalman due to export/test_torchbind.py::TestCompileTorchbind::test_compile_error_on_input_aliasing_contents_backend_aot_eager [GH job link](https://github.com/pytorch/pytorch/actions/runs/15804799771/job/44548489912) [HUD commit link](c95f7fa874) ([comment](https://github.com/pytorch/pytorch/pull/156313#issuecomment-2994171213))
2025-06-22 12:31:57 +00:00
Xuehai Pan
c2f0292bd5 [BE][5/16] fix typos in torch/ (torch/distributed/) (#156315)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156315
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #156313, #156314
2025-06-22 08:43:26 +00:00
Aaron Gokaslan
d859e65826 [DCP][Ez]: Fix broadcast_object bug in DCP utils (#155912)
Fixes #152310. Broadcast_object is now symmetric with gather_object and scatter_object. It was likely a typo that wasn't fixed in https://github.com/pytorch/pytorch/pull/147675

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155912
Approved by: https://github.com/ezyang
2025-06-14 12:14:14 +00:00
Saurabh Mishra
7d2411d30e [DCP][OSS] Introduce barrier util in the DistWrapper for rank local checkpointing (#150748)
Summary: Introduce barrier util in the DistWrapper for rank local checkpointing. This barrier will be used at the end of the rank local checkpointing to ensure all ranks synchronize.

Test Plan: UTs

Differential Revision: D72541431

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150748
Approved by: https://github.com/MeetVadakkanchery
2025-04-07 17:33:07 +00:00
lanzongwei.lan
3d62e81a1e [DCP] fix dcp gather_object/scatter_object_list (#147675)
gather_object/scatter_object_list's dst is `Destination rank on global process group (regardless of group argument)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147675
Approved by: https://github.com/MeetVadakkanchery
2025-03-06 21:20:38 +00:00
Xuehai Pan
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
Ke Wen
762a05b3b3 [DCP] Remove all-gather of state dict keys (#145998)
The original `_all_gather_keys` call was for a safety check, but could be costly as things scale, and it blocks CPU.

Instead, we make it clear in the documentation that the `state_dict` passed to the `load` API should have same set of keys, otherwise the API may hang.

In addition, we move the check to a utility function: `utils.assert_same_keys`. User uncertain about state dict unity can optionally call this API to check.

Resolves #145965 (as a workaround).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145998
Approved by: https://github.com/mhorowitz, https://github.com/fegin
2025-02-04 03:16:13 +00:00
Aaron Orenstein
316808e4e9 PEP585 update - torch/distributed/elastic torch/distributed/checkpoint (#145163)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145163
Approved by: https://github.com/Skylion007
2025-01-19 20:55:59 +00:00
Marc Horowitz
95d333f52e [distributed] Fix _ReaderView.read() and readinto() to stop reading at the end of the slice (#143357)
_ReaderView doesn't work correctly if the slice ends past the view.

read(-1) would call read(-1) on the base_stream, which would consume the entire underlying stream, even if the view ended before that.
read(n) would read n bytes, even if the view ended before that.

The new implementation clamps the size read to the size of the view.

readinto(b) would read len(b) bytes, even if the view ended before that.

Since the interface depends on the size of b, we use a (potentially) shortened view into b to avoid a copy.  If the view doesn't contain enough data to fill the view, then this will appear as end of stream to the caller, which is the desired behavior.

This fix should not be user facing, since the bug is in an internal helper, and is only visible with new code down the stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143357
Approved by: https://github.com/saumishr
2025-01-11 00:22:10 +00:00
Xuehai Pan
b77406a9ec [BE][CI] bump ruff to 0.8.4 (#143753)
Changes:

1. Bump `ruff` from 0.7.4 to 0.8.4
2. Change `%`-formatted strings to f-string
3. Change arguments with the `__`-prefix to positional-only arguments with the `/` separator in function signature.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143753
Approved by: https://github.com/Skylion007
2024-12-24 12:24:10 +00:00
Saurabh Mishra
dd7cd182ab [AIInfra][DCP] All gather keys checkpoint utils bug fix (#135045)
Summary: All gather keys checkpoint utils bug fix. Dist. get_world_size should have the process group passed in to avoid inconsistent world size in case the process group has changed. This is common in the tests.

Test Plan: UTs

Reviewed By: Saiteja64

Differential Revision: D61578832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135045
Approved by: https://github.com/MeetVadakkanchery, https://github.com/LucasLLC
2024-09-04 18:49:34 +00:00
Zain Huda
0acd09aecd [torchrec][pt-d][model store] introduce LocalShardsWrapper for DTensor (#129150)
Summary:
Same as D57688538, recreated because of GH issues

This diff introduces LocalShardsWrapper which is crucial to migrating from using ShardedTensor to DTensor in TRec state dict representation. As well as any changes needed in PT-D and ModelStore to support this.

It allows us to extend DTensor to support multiple shards on a rank as well as empty shards on a rank as needed by TRec sharding logic.

This diff also extends the support for LocalShardsWrapper to be used in conjunction with DTensor in checkpointing cases (ModelStore and DCP)

See D54375878 for how it is used.

**LocalShardsWrapper supports the following torch ops:**
+ torch.ops._c10d_functional.all_gather_into_tensor.default
+ aten._to_copy.default
+ aten.view.default
+ aten.equal.default
+ aten.detach.default

With extensibility to add more as required by use cases.

See https://docs.google.com/document/d/16Ptl50mGFJW2cljdF2HQ6FwsiA0scwbAbjx_4dhabJw/edit?usp=drivesdk for more info regarding design and approach.

NOTE: This version of LocalShardsWrapper does not support empty shards, that is added in the next diff enabling CW. D57063512

Test Plan:
` buck test mode/opt -c python.package_style=inplace aiplatform/modelstore/client/tests_gpu:dist_checkpoint_save_load_with_stateful_tests -- --print-passing-details`

`buck2 test 'fbcode//mode/dev-nosan' fbcode//torchrec/distributed/tests:test_tensor_configs -- --print-passing-details`

Sandcastle

Reviewed By: XilunWu, wanchaol

Differential Revision: D58570479

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129150
Approved by: https://github.com/XilunWu
2024-06-21 01:58:51 +00:00
Xuehai Pan
e6d4451ae8 [BE][Easy] enable UFMT for torch/distributed/{algorithms,autograd,benchmarks,checkpoint,elastic}/ (#128866)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128866
Approved by: https://github.com/fegin
2024-06-18 13:51:53 +00:00
Aaron Orenstein
3a0d088517 Flip default value for mypy disallow_untyped_defs [5/11] (#127842)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127842
Approved by: https://github.com/oulgen
2024-06-08 18:49:18 +00:00
Zain Huda
6d4ec9b2ec [RFC] Introduce Checkpointable for DCP (#127540) (#127628)
Summary:
# Introduce Checkpointable interface for DCP to support arbitrary tensor subclasses for checkpointing

**Authors:**
* zainhuda

## **Summary**
This diff adds a CheckpointableTensor interface to allow for future compatibility for any tensor subclass with DCP in a clean and maintainable way.

## **Motivation**
For TorchRec sharding migration from ShardedTensor to DTensor, we create a tensor subclass that is stored by DTensor to support TorchRec's sharding schemes (ex, empty shards, multiple shards on a rank).

## **Proposed Implementation**
View the CheckpointableTensor interface implementation, in which, we introduce the minimal set of methods needed to be compatible with DCP. These methods are expected to implemented by any tensor subclasses and as such are then checkpointable by DCP.

## **Drawbacks**
No drawbacks, it extends functionality in a clean and maintainable way.

## **Alternatives**
Alternative design was creating paths for checking for certain attributes in tensor subclasses which can get messy and hard to maintain/understand why it was there in the first place.

Test Plan:
Sandcastle

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k LucasLLC

Differential Revision: D57970603

Pulled By: iamzainhuda

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127628
Approved by: https://github.com/wz337, https://github.com/XilunWu, https://github.com/fegin
2024-06-03 21:21:55 +00:00
Chien-Chin Huang
644bc69530 [DCP] Allow users to save and load without creating storage reader and writer (#117772)
Right now DCP API requires users to create StorageWriter and StorageReader for every API call. This PR allows users to only pass the checkpointer_id (a path) and use it to read/write a checkpoint without creating a StorageReader and Writer.

Differential Revision: [D52740556](https://our.internmc.facebook.com/intern/diff/D52740556/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117772
Approved by: https://github.com/wz337
ghstack dependencies: #116248
2024-01-26 09:08:35 +00:00
Lucas Pasqualin
b10b08227a Passes process group to _all_gather_keys in dcp.load (#118301)
As title

Fixes #118277

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118301
Approved by: https://github.com/Skylion007, https://github.com/fegin
2024-01-25 23:07:57 +00:00
Chien-Chin Huang
f170d6665c [DCP] Add a profiler function for benchmarking save and load (#116007)
Many operations when calling DCP's save and load are executed on CPU. Thus we can easily profile these operations with cProfile. This PR adds the ability to profile the save() and load()

One follow-up for this PR is to integrate the feature with the distributed logging flags.

Differential Revision: [D52245434](https://our.internmc.facebook.com/intern/diff/D52245434/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116007
Approved by: https://github.com/LucasLLC, https://github.com/wz337
ghstack dependencies: #116006
2023-12-21 08:03:07 +00:00
Chien-Chin Huang
db8d409d08 [DCP][BE] Apply ufmt to DCP and turn on lintrunner for DCP (#115302)
No logic change. Just typing and ufmt.

Differential Revision: [D51914982](https://our.internmc.facebook.com/intern/diff/D51914982/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115302
Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/LucasLLC
ghstack dependencies: #115523
2023-12-13 10:32:36 +00:00
Lucas Pasqualin
753c07bbe0 All gather keys before processing Stateful objects in save/load [2/N] (#114304)
Accounts for the case where `state_dict` keys may present in different orders. Since users may be calling collectives in `state_dict` and `load_state_dict` call, different ordered keys could cause a deadlock. This is mostly a defensive move, meant to match the feature in TSS.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114304
Approved by: https://github.com/fegin, https://github.com/wz337
2023-12-04 18:31:14 +00:00
NVS Abhilash
44c0521e8c fix: docstring error in torch/distributed module (#113241)
Fixes: #113193

`pydocstyle <all_files_in_issue> --count`

- Before: 345
- After: 130

For deprecated methods, I have added a `noqa` to ignore them. I was not able to find the file `torch/distributed/tensor/parallel/multihead_attention_tp.py`, so I've ignored it for this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113241
Approved by: https://github.com/kit1980
2023-11-09 19:10:20 +00:00
dilililiwhy
ff37f6018d Enable custom device support in fsdp checkpoint (#107289)
Fixes https://github.com/pytorch/pytorch/issues/104390
Enable custom device(privateuse1 backend) support in checkpointing by a dynamic abstract device module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107289
Approved by: https://github.com/wz337
2023-08-25 11:50:03 +00:00
Rodrigo Kumpera
4833dc10b8 [DCP] Rewrite read slicing to use a wrapper. (#99167)
Moved SlicedBufferedReader to utils and renamed to _ReaderView.

It no longer depends on file handles and is a pure wrapper. This makes it general enought to handle non io stream objects like fsspec's.

Should help with #98386
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99167
Approved by: https://github.com/wz337
2023-06-08 13:52:13 +00:00
Iris
bb347dc3c3 [PTD][DCP] Add 1D DTensor based DCP (#94868)
Add 1D DTensor based DCP along with its test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94868
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-02-16 23:38:04 +00:00
Iris
22e7514a15 [Checkpoint][2D][3/N] Add nested_tensors for distributed checkpoint to core distributed (#89501)
This PR moves nested_tensors to torch.distributed.checkpoint. This is a pre-req for enabling 2D checkpoint.

This flattens sharded tensors in state_dict. It is used when saving and loading FSDP SHARDED_STATE_DICT.

Docstring, individual and integration test will be added in the following PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89501
Approved by: https://github.com/wanchaol
2022-11-28 23:21:38 +00:00
Iris
aee96bbf5a [PT-D][Checkpointing] Move distributed checkpointing from torch.distributed._shard.checkpoint to torch.distributed.checkpoint (#88698)
Context in RFC: https://github.com/pytorch/pytorch/issues/86620

.rst file will be finalized in subsequent PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88698
Approved by: https://github.com/wanchaol
2022-11-16 21:06:38 +00:00