pytorch/torch/distributed/fsdp
wz337 d9eb5a57aa [FSDP] Change _create_chunk_dtensor in fsdp/_shard_utils.py to use public API from DTensor (#110831)
This PR:
1) updates _create_chunk_dtensor() in _shard_utils.py to use public APIs from DTensor. This avoids the global_size calculation error that arises from using DTensor.from_local() for unevenly sharded parameters, as described in https://github.com/pytorch/pytorch/issues/110762; a short sketch of the distinction follows below.
2) updates test/distributed/fsdp/test_fsdp_dtensor_state_dict.py to add a unit test for a model with uneven sharding.
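For illustration only, here is a minimal sketch (not the PR's actual diff) of why DTensor.from_local() infers the wrong global size for an uneven shard, and how the public distribute_tensor() API avoids it. The two-rank CPU/gloo setup, the script name, and the torchrun invocation are assumptions made for this example:

# uneven_shard_demo.py -- hypothetical demo, run with:
#   torchrun --nproc_per_node=2 uneven_shard_demo.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Shard, distribute_tensor

def main() -> None:
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = DeviceMesh("cpu", list(range(world_size)))

    # A 5-element parameter sharded over 2 ranks is uneven:
    # rank 0 holds 3 elements, rank 1 holds 2.
    full = torch.arange(5, dtype=torch.float32)

    # Public API: distribute_tensor() chunks the *full* tensor itself,
    # so the global shape it records is exact even for uneven shards.
    dt = distribute_tensor(full, mesh, [Shard(0)])
    assert dt.shape == torch.Size([5])

    # from_local() only sees the local shard; with run_check=False it
    # assumes even sharding and infers global size as
    # local_size * world_size: 3 * 2 = 6 on rank 0, 2 * 2 = 4 on rank 1.
    local = full.chunk(world_size)[rank]
    bad = DTensor.from_local(local, mesh, [Shard(0)], run_check=False)
    print(f"rank {rank}: from_local inferred shape {tuple(bad.shape)}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

On this sketch, rank 0 reports an inferred shape of (6,) and rank 1 reports (4,), while the true parameter length is 5, which is the mismatch described in issue #110762.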

cc. @wanchaol, @fegin

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110831
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-10-10 21:04:27 +00:00
__init__.py Define the public API for torch.distributed.fsdp (#109922) 2023-09-28 02:15:58 +00:00
_common_utils.py Define the public API for torch.distributed.fsdp (#109922) 2023-09-28 02:15:58 +00:00
_debug_utils.py Define the public API for torch.distributed.fsdp (#109922) 2023-09-28 02:15:58 +00:00
_dynamo_utils.py
_exec_order_utils.py [FSDP] fix: fix for fsdp exec order pre fwd record (#110138) 2023-09-28 15:45:05 +00:00
_flat_param.py Define the public API for torch.distributed.fsdp (#109922) 2023-09-28 02:15:58 +00:00
_fsdp_extensions.py [FSDP][optim_state_dict] Add device to _shard_utils.py to explicitly use the device from fsdp_state (#109631) 2023-09-20 01:59:38 +00:00
_init_utils.py Define the public API for torch.distributed.fsdp (#109922) 2023-09-28 02:15:58 +00:00
_limiter_utils.py
_optim_utils.py [FSDP][optim_state_dict] Move local optimizer state to FSDP compute_device (#110929) 2023-10-10 10:34:31 +00:00
_runtime_utils.py Log usage of optimizer in backward (#110206) 2023-09-29 11:00:07 +00:00
_shard_utils.py [FSDP] Change _create_chunk_dtensor in fsdp/_shard_utils.py to use public API from DTensor (#110831) 2023-10-10 21:04:27 +00:00
_state_dict_utils.py [FSDP] Remove _set_use_dtensor in post_load_state_dict_hook (#109924) 2023-09-23 22:34:36 +00:00
_trace_utils.py [Reland] Update mypy to 1.4.1 (#105227) 2023-07-15 20:30:20 +00:00
_traversal_utils.py Migrate tuple(handle) -> handle (#104488) 2023-07-19 22:33:35 +00:00
_unshard_param_utils.py Define the public API for torch.distributed.fsdp (#109922) 2023-09-28 02:15:58 +00:00
_wrap_utils.py [FSDP][9/N] Introduce CustomPolicy (#104986) 2023-08-03 12:46:36 +00:00
api.py [FSDP][optim_state_dict] Enable cpu_offload config for optimizer state_dict (#108434) 2023-10-07 01:14:49 +00:00
fully_sharded_data_parallel.py [FSDP][optim_state_dict] Enable cpu_offload config for optimizer state_dict (#108434) 2023-10-07 01:14:49 +00:00
sharded_grad_scaler.py Improve torch.cuda.amp type hints (#108630) 2023-09-08 06:06:25 +00:00
wrap.py [FSDP][Wrap] ModuleWrapPolicy callable (#109117) 2023-09-14 07:14:18 +00:00