pytorch/torch
Chien-Chin Huang c4fc5d372f [FSDP][state_dict][1/N] Moving state_dict logic to pre_state_dict_hook (#87900)
This is one step toward the ultimate goal: removing the overridden `state_dict` in FSDP. All the logic should live in either `pre_state_dict_hook` or `post_state_dict_hook`.

Since the current `nn.Module` does not support `pre_state_dict_hook`, this PR mimics `pre_state_dict_hook` by calling the pre hook inside the post hook, effectively discarding all the work done by `nn.Module.state_dict`. Once `nn.Module` supports `pre_state_dict_hook`, these pre hook calls can be moved out of the post hooks and registered to `nn.Module.pre_state_dict_hook`.

The major issue with this temporary solution is that `post_state_dict_hook` is called from the leaf nodes up to the root node. This makes calling `module._lazy_init()` from the hook invalid, because FSDP assumes `_lazy_init()` is called from the root. As a result, `FSDP.state_dict` currently contains only one piece of logic: calling `module._lazy_init()`.
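
The ordering problem can be illustrated with a minimal sketch. It does not use FSDP or `nn.Module`; `ToyModule`, `pre_hook`, `post_hook`, and `call_order` are hypothetical names for illustration only. The toy `state_dict` is assumed to follow the same traversal shape as described above: save own entries, recurse into children, then run the module's post hooks. Because the emulated pre hook is invoked from inside the post hook, it fires for the leaf before it fires for the root.

```python
from collections import OrderedDict


class ToyModule:
    """Hypothetical stand-in for a module tree; not the real nn.Module."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)
        self.post_hooks = []

    def state_dict(self, destination=None, prefix=""):
        if destination is None:
            destination = OrderedDict()
        # 1. Save this module's own entries.
        destination[prefix + "weight"] = f"{self.name}.weight"
        # 2. Recurse into children, sharing the same destination dict.
        for child in self.children:
            child.state_dict(destination, prefix + child.name + ".")
        # 3. Post hooks run only after all descendants have finished.
        for hook in self.post_hooks:
            hook(self, destination, prefix)
        return destination


call_order = []


def pre_hook(module, prefix):
    # Work that would ideally run before state_dict touches the module.
    call_order.append(("pre", module.name))


def post_hook(module, destination, prefix):
    # Emulation: the pre hook is called from inside the post hook.
    pre_hook(module, prefix)
    call_order.append(("post", module.name))


leaf = ToyModule("leaf")
root = ToyModule("root", [leaf])
for m in (root, leaf):
    m.post_hooks.append(post_hook)

state = root.state_dict()
print(call_order)
```

In this sketch the emulated pre hook for `leaf` runs before the one for `root`, which is why any initialization that must start at the root cannot live in the hook and has to stay in the outer `state_dict` call.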

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87900
Approved by: https://github.com/rohan-varma
2022-11-11 03:41:40 +00:00
_C add DisableTorchFunction that matches DisableTorchDispatch (#88219) 2022-11-10 14:51:13 +00:00
_C_flatbuffer
_decomp [primTorch] Implement group norm reference (#87054) 2022-11-11 01:08:20 +00:00
_dispatch
_dynamo [Dynamo] Add complete support for Tensor.is_contiguous (#88407) 2022-11-10 23:47:21 +00:00
_inductor Assert we have triton before scheduling on triton (#88849) 2022-11-11 02:30:29 +00:00
_lazy
_prims Add min cut partitioner for AOT+nvFuser (#88204) 2022-11-09 12:56:55 +00:00
_prims_common Ref for aten.full; symint changes in prim (#88762) 2022-11-11 02:32:09 +00:00
_refs Ref for aten.full; symint changes in prim (#88762) 2022-11-11 02:32:09 +00:00
_subclasses rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218) 2022-11-10 14:51:13 +00:00
amp
ao [ao] qconfig_mapping.py fixing public v private (#87518) 2022-11-11 00:32:24 +00:00
autograd Make Python op registration work with torchdeploy/multipy (#87162) 2022-11-03 12:56:44 +00:00
backends Add mem efficient backend flag (#87946) 2022-10-28 15:51:10 +00:00
contrib
cpu
csrc [ONNX] Improve diagnostic message formatting (#87830) 2022-11-10 21:42:17 +00:00
cuda Deprecate TypedStorage, its derived classes, and all of their public methods (#85303) 2022-11-08 18:11:01 +00:00
distributed [FSDP][state_dict][1/N] Moving state_dict logic to pre_state_dict_hook (#87900) 2022-11-11 03:41:40 +00:00
distributions Fix typos used in documents under torch directory (#88483) 2022-11-08 01:33:36 +00:00
fft Fix typos used in documents under torch directory (#88300) 2022-11-02 09:38:13 +00:00
futures
fx Symbolic shape: sym_floor , sym_sqrt, sym_int (#88760) 2022-11-10 23:41:33 +00:00
jit prepare removal of deprecated functionality in torch.testing (#87969) 2022-11-02 14:04:48 +00:00
legacy
lib
linalg Fix typos used in documents under torch directory (#88300) 2022-11-02 09:38:13 +00:00
masked rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218) 2022-11-10 14:51:13 +00:00
monitor
multiprocessing Deprecate TypedStorage, its derived classes, and all of their public methods (#85303) 2022-11-08 18:11:01 +00:00
nested Implement a constructor for nested_tensor that is similar to torch.tensor() (#88213) 2022-11-08 00:03:18 +00:00
nn Disable check for dropout in MultiheadAttention fast_path (#88831) 2022-11-11 03:34:57 +00:00
onnx [ONNX] Improve diagnostic message formatting (#87830) 2022-11-10 21:42:17 +00:00
optim Publicly expose _LRScheduler to LRScheduler (#88503) 2022-11-07 21:15:10 +00:00
package Deprecate TypedStorage, its derived classes, and all of their public methods (#85303) 2022-11-08 18:11:01 +00:00
profiler [Profiler] Memory profiler part 1: Gradient identification (#86802) 2022-11-08 23:53:13 +00:00
quantization [ao] fuser_method_mappings.py fixing public v private (#87516) 2022-11-10 21:37:31 +00:00
signal Reimplement Kaiser window (#87330) 2022-10-27 21:01:01 +00:00
sparse
special
testing Ref for aten.full; symint changes in prim (#88762) 2022-11-11 02:32:09 +00:00
utils [DataPipe] Deprecating drop_empty_batches from Filter and other functional APIs (#88693) 2022-11-10 19:54:22 +00:00
__config__.py
__future__.py
__init__.py Revert "Add nondeterministic error for scatter (#88244)" 2022-11-10 23:56:49 +00:00
_appdirs.py
_classes.py
_deploy.py Deprecate TypedStorage, its derived classes, and all of their public methods (#85303) 2022-11-08 18:11:01 +00:00
_jit_internal.py
_linalg_utils.py
_lobpcg.py Fix typos used in documents under torch directory (#88300) 2022-11-02 09:38:13 +00:00
_lowrank.py
_meta_registrations.py Add meta support for scalar_tensor and argmax (#88590) 2022-11-11 01:31:00 +00:00
_namedtensor_internals.py
_ops.py OpOverload is_view (#88722) 2022-11-09 19:03:12 +00:00
_python_dispatcher.py
_six.py
_sources.py
_storage_docs.py
_tensor_docs.py Revert "[primTorch] Improve narrow and narrow_copy: refs, tests, docs (#87045)" 2022-11-09 20:48:32 +00:00
_tensor_str.py Disable Current Modes when printing Tensor (#88344) 2022-11-04 00:45:35 +00:00
_tensor.py rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218) 2022-11-10 14:51:13 +00:00
_torch_docs.py Revert "[primTorch] Improve narrow and narrow_copy: refs, tests, docs (#87045)" 2022-11-09 20:48:32 +00:00
_utils_internal.py
_utils.py [fix] MathBits: serialization (#88182) 2022-11-09 17:15:12 +00:00
_VF.py
_vmap_internals.py
_weights_only_unpickler.py Revert "[fix] allow saving python attr on Tensor and Parameter via torch.save (#81616)" 2022-11-07 18:51:16 +00:00
abi-check.cpp
CMakeLists.txt
custom_class_detail.h
custom_class.h
extension.h
functional.py Fix typos used in documents under torch directory (#88300) 2022-11-02 09:38:13 +00:00
hub.py
library.h Make Python op registration work with torchdeploy/multipy (#87162) 2022-11-03 12:56:44 +00:00
library.py Make Python op registration work with torchdeploy/multipy (#87162) 2022-11-03 12:56:44 +00:00
overrides.py Deprecate TypedStorage, its derived classes, and all of their public methods (#85303) 2022-11-08 18:11:01 +00:00
py.typed
quasirandom.py
random.py
README.txt
return_types.py
script.h
serialization.py Deprecate TypedStorage, its derived classes, and all of their public methods (#85303) 2022-11-08 18:11:01 +00:00
storage.py Deprecate TypedStorage, its derived classes, and all of their public methods (#85303) 2022-11-08 18:11:01 +00:00
torch_version.py
types.py

Note [TH abstraction violation]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TH/THC provide some hpp headers, which are proper C++ headers rather than
C headers.  These headers serve double duty as *internal implementation
detail* headers, whose contents should largely not be used by external
clients.

Ideally, we would not install these headers at all; instead, you should
use public functions (in headers like `THTensor.h`, NOT `THTensor.hpp`)
to manipulate these structs.  However, there are a few places
in torch/csrc where we violate this abstraction.  They are marked with
a pointer to this note.  Each of those sites will have to be refactored
when we refactor the guts of THTensor and related structures.