The original `_all_gather_keys` call was a safety check, but it could become costly at scale, and it blocks the CPU.
Instead, we make it clear in the documentation that the `state_dict` passed to the `load` API must have the same set of keys on every rank; otherwise the API may hang.
In addition, we move the check to a utility function: `utils.assert_same_keys`. Users who are unsure whether their state dict keys match across ranks can optionally call this API to check.
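For example, a cautious caller could do something like the following (a minimal sketch; `assert_same_keys` is the utility named in this PR, while the exact import path, the `dcp.load` arguments, and the checkpoint path are assumptions):

```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import utils

# Assumes the process group is already initialized (e.g. via torchrun).
model = torch.nn.Linear(8, 8)
state_dict = model.state_dict()

# Optional safety check: verifies every rank passed the same set of keys.
# This is itself a collective, so either all ranks call it or none do.
utils.assert_same_keys(state_dict)

dcp.load(state_dict, checkpoint_id="/tmp/ckpt")
```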
Resolves #145965 (as a workaround).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145998
Approved by: https://github.com/mhorowitz, https://github.com/fegin
`_ReaderView` doesn't work correctly when a read extends past the end of the view.
`read(-1)` would call `read(-1)` on the `base_stream`, which would consume the entire underlying stream, even if the view ended before that.
`read(n)` would read `n` bytes, even if the view ended before that.
The new implementation clamps the read size to the remaining size of the view.
`readinto(b)` would read `len(b)` bytes, even if the view ended before that.
Since the interface depends on the size of `b`, we use a (potentially) shortened memoryview into `b` to avoid a copy. If the view doesn't contain enough data to fill the buffer, this appears as end-of-stream to the caller, which is the desired behavior.
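The clamping idea, schematically (an illustrative sketch, not the actual helper; the class and names here are hypothetical):

```python
import io

class ReaderViewSketch:
    """Illustrative window over [offset, offset + length) of a base stream."""

    def __init__(self, base_stream, offset: int, length: int):
        self.base_stream = base_stream
        self.end = offset + length
        base_stream.seek(offset)

    def _remaining(self) -> int:
        return max(0, self.end - self.base_stream.tell())

    def read(self, size: int = -1) -> bytes:
        # Clamp to the view: read(-1) must not consume the whole base stream,
        # and read(n) must not run past the end of the view.
        remaining = self._remaining()
        size = remaining if size < 0 else min(size, remaining)
        return self.base_stream.read(size)

    def readinto(self, b) -> int:
        # Shorten the destination via a memoryview (no copy); a partial fill
        # looks like end-of-stream to the caller, which is what we want.
        view = memoryview(b)[: self._remaining()]
        return self.base_stream.readinto(view)

base = io.BytesIO(b"0123456789")
view = ReaderViewSketch(base, offset=2, length=4)
assert view.read(-1) == b"2345"  # clamped; the trailing bytes stay unread
```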
This fix should not be user-facing: the bug is in an internal helper and is only visible with new code further down the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143357
Approved by: https://github.com/saumishr
Changes:
1. Bump `ruff` from 0.7.4 to 0.8.4
2. Change `%`-formatted strings to f-strings
3. Change arguments with the `__` prefix to positional-only arguments with the `/` separator in function signatures (see the sketch below).
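Illustrative before/after shapes for changes 2 and 3 (hypothetical snippets, not actual call sites from this PR):

```python
# Change 2: %-formatting -> f-string
name, count = "foo", 3
msg = "%s has %d items" % (name, count)  # before
msg = f"{name} has {count} items"        # after

# Change 3: `__`-prefixed parameters -> positional-only via `/`
def get(self, __key):   # before: the dunder prefix discourages keyword use
    ...

def get(self, key, /):  # after: `/` makes `key` genuinely positional-only
    ...
```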
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143753
Approved by: https://github.com/Skylion007
Summary:
Same as D57688538, recreated because of GH issues.
This diff introduces `LocalShardsWrapper`, which is crucial to migrating the TRec state dict representation from `ShardedTensor` to `DTensor`, along with the changes needed in PT-D and ModelStore to support it.
It allows us to extend `DTensor` to support multiple shards on a rank, as well as empty shards on a rank, as needed by TRec sharding logic.
This diff also extends `LocalShardsWrapper` to be used in conjunction with `DTensor` in checkpointing cases (ModelStore and DCP).
See D54375878 for how it is used.
**LocalShardsWrapper supports the following torch ops:**
+ torch.ops._c10d_functional.all_gather_into_tensor.default
+ aten._to_copy.default
+ aten.view.default
+ aten.equal.default
+ aten.detach.default
The set is extensible, so more ops can be added as use cases require; a schematic of the dispatch pattern follows.
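A skeletal illustration of the dispatch pattern (a hand-written sketch, not the actual LocalShardsWrapper implementation; the class name and handler details are assumptions):

```python
import torch

class ShardsWrapperSketch(torch.Tensor):
    """Toy wrapper holding a list of local shards, for dispatch illustration."""

    @staticmethod
    def __new__(cls, shards, global_size):
        # A real implementation would also derive strides, device, etc.
        r = torch.Tensor._make_wrapper_subclass(cls, global_size, dtype=shards[0].dtype)
        r._shards = shards
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Route each supported op to a handler over the local shards.
        if func is torch.ops.aten.detach.default:
            (self,) = args
            return ShardsWrapperSketch([s.detach() for s in self._shards], self.shape)
        if func is torch.ops.aten.equal.default:
            a, b = args
            return len(a._shards) == len(b._shards) and all(
                torch.equal(x, y) for x, y in zip(a._shards, b._shards)
            )
        raise NotImplementedError(f"{func} is not handled by this sketch")
```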
See https://docs.google.com/document/d/16Ptl50mGFJW2cljdF2HQ6FwsiA0scwbAbjx_4dhabJw/edit?usp=drivesdk for more info regarding design and approach.
NOTE: This version of LocalShardsWrapper does not support empty shards; that support is added in the next diff, D57063512, which enables CW.
Test Plan:
`buck test mode/opt -c python.package_style=inplace aiplatform/modelstore/client/tests_gpu:dist_checkpoint_save_load_with_stateful_tests -- --print-passing-details`
`buck2 test 'fbcode//mode/dev-nosan' fbcode//torchrec/distributed/tests:test_tensor_configs -- --print-passing-details`
Sandcastle
Reviewed By: XilunWu, wanchaol
Differential Revision: D58570479
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129150
Approved by: https://github.com/XilunWu
Summary:
# Introduce Checkpointable interface for DCP to support arbitrary tensor subclasses for checkpointing
**Authors:**
* zainhuda
## **Summary**
This diff adds a `CheckpointableTensor` interface to allow future compatibility between any tensor subclass and DCP in a clean and maintainable way.
## **Motivation**
For the TorchRec sharding migration from `ShardedTensor` to `DTensor`, we create a tensor subclass that is stored by `DTensor` to support TorchRec's sharding schemes (e.g., empty shards, multiple shards on a rank).
## **Proposed Implementation**
View the `CheckpointableTensor` interface implementation, in which we introduce the minimal set of methods needed to be compatible with DCP. These methods are expected to be implemented by any tensor subclass, and such subclasses are then checkpointable by DCP.
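Roughly, the shape of such an interface (a sketch only; the method names and signatures below are illustrative assumptions, and the authoritative set is in the PR itself):

```python
from abc import ABC, abstractmethod
from typing import Any, List

import torch

class CheckpointableTensorSketch(ABC):
    """Illustrative minimal contract a tensor subclass might satisfy for DCP."""

    @abstractmethod
    def __create_write_items__(self, fqn: str, obj: Any) -> List[Any]:
        """Describe how this tensor's data should be written, one item per shard."""

    @abstractmethod
    def __create_chunk_list__(self) -> List[Any]:
        """Enumerate the chunks (offsets and sizes) this tensor contributes."""

    @abstractmethod
    def __get_tensor_shard__(self, index: Any) -> torch.Tensor:
        """Return the local tensor data backing a given chunk index."""
```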
## **Drawbacks**
No drawbacks; it extends functionality in a clean and maintainable way.
## **Alternatives**
The alternative design was to add code paths that check for particular attributes on tensor subclasses, which can get messy, hard to maintain, and hard to understand why a given check was there in the first place.
Test Plan:
Sandcastle
cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k LucasLLC
Differential Revision: D57970603
Pulled By: iamzainhuda
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127628
Approved by: https://github.com/wz337, https://github.com/XilunWu, https://github.com/fegin
Accounts for the case where `state_dict` keys may be present in different orders across ranks. Since users may be calling collectives in their `state_dict` and `load_state_dict` calls, differently ordered keys could cause a deadlock. This is mostly a defensive move, meant to match the feature in TSS.
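The defensive pattern, in spirit (a sketch, not the PR's actual diff): iterate keys in a deterministic order so that any per-key collectives fire in the same sequence on every rank.

```python
def visit_state_dict_in_order(state_dict, visit_fn):
    # Sorting gives every rank the same iteration order even if the dicts
    # were populated differently, so per-key collectives stay aligned.
    for key in sorted(state_dict.keys()):
        visit_fn(key, state_dict[key])
```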
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114304
Approved by: https://github.com/fegin, https://github.com/wz337
Fixes: #113193
Error count from `pydocstyle <all_files_in_issue> --count`:
- Before: 345
- After: 130
For deprecated methods, I have added a `noqa` to ignore them. I was not able to find the file `torch/distributed/tensor/parallel/multihead_attention_tp.py`, so I've ignored it for this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113241
Approved by: https://github.com/kit1980
Moved `SlicedBufferedReader` to utils and renamed it to `_ReaderView`.
It no longer depends on file handles and is a pure wrapper. This makes it general enough to handle non-`io` stream objects like fsspec's.
Should help with #98386
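Illustration of the generality (`io.BytesIO` stands in for any seekable, readable object such as an fsspec file handle; the import path and constructor signature are assumptions based on this PR's description):

```python
import io
from torch.distributed.checkpoint.utils import _ReaderView  # assumed location

# Any seekable, readable object works -- no OS file handle required.
base = io.BytesIO(b"headerPAYLOADtrailer")
view = _ReaderView(base, 6, 7)  # assumed (base_stream, offset, length) order
assert view.read(7) == b"PAYLOAD"
```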
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99167
Approved by: https://github.com/wz337
This PR moves nested_tensors to torch.distributed.checkpoint. This is a prerequisite for enabling 2D checkpointing.
It flattens sharded tensors in the state_dict, and is used when saving and loading FSDP's SHARDED_STATE_DICT.
Docstrings, individual tests, and an integration test will be added in follow-up PRs.
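The flattening idea, schematically (a toy sketch with dot-joined keys; the real mapping logic lives in the moved module):

```python
def flatten_state_dict_sketch(state_dict, prefix="", out=None):
    # Recursively flatten nested dicts into {"a.b.c": tensor} form so that
    # each (possibly sharded) leaf gets a single fully qualified name.
    out = {} if out is None else out
    for key, value in state_dict.items():
        fqn = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flatten_state_dict_sketch(value, fqn, out)
        else:
            out[fqn] = value
    return out
```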
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89501
Approved by: https://github.com/wanchaol