A proposal addressing Issue #1489: **Optimizer should track parameter names and not id.** (also mentioned here: [[RFC] Introducing FQNs/clarity eyeglasses to optim state_dict](https://dev-discuss.pytorch.org/t/rfc-introducing-fqns-clarity-to-optim-state-dict/1552))

## Summary

This PR introduces a backward-compatible enhancement where optimizers track parameter names, not just their ids. Optimizers can be initialized with `named_parameters()` as:

```python
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
```

This allows for greater clarity and ease when handling optimizers, as the parameters' names are preserved within the optimizer's `state_dict`:

```
state_dict = {
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0, 1],
            'param_names': ['layer.weight', 'layer.bias']  (optional)
        }
    ]
}
```

Loading a `state_dict` is unchanged (backward-compatible), and the `param_names` key is ignored.

## Key Features

#### Named Parameters in Optimizer Initialization
Optimizers can accept the output of `model.named_parameters()` during initialization, allowing them to store parameter names directly.

#### Parameter Names in `state_dict`
The parameter names are saved as a list in the optimizer's `state_dict` under the key `param_names`, alongside the `params` indices, ensuring seamless tracking of both names and parameters.

## Backward Compatibility

#### No Breaking Changes
This change is fully backward-compatible. The added `param_names` key in the optimizer's `state_dict` is ignored when loading a state into the optimizer.

#### Customization with Hooks
For more control, the loaded `state_dict` can be modified using a custom `register_load_state_dict_pre_hook`, providing flexibility for different design needs.

## Documentation Updates

Please refer to the documentation changes for more details on how this feature is implemented and how it can be used effectively.

## Solution Example

A suggested solution to the problem raised in #1489: loading a state dict saved with the same parameters, but in a different order. The following `register_load_state_dict_pre_hook` should be added to the optimizer before loading to enable loading the state dict (a runnable end-to-end sketch appears below):

```python
def adapt_state_dict_ids(optimizer, state_dict):
    # assuming a single param group.
    current_state_group = optimizer.state_dict()['param_groups'][0]
    loaded_state_group = state_dict['param_groups'][0]

    # same number of params, same names, only different ordering
    loaded_name_to_id_mapping = {}  # mapping -- param_name: id in the loaded state dict
    for i, name in enumerate(loaded_state_group['param_names']):
        loaded_name_to_id_mapping[name] = loaded_state_group['params'][i]

    # changing the ids of the loaded state dict to match the order of the given state dict:
    # position i must hold the loaded id of the parameter named at position i
    # in the current optimizer, since load_state_dict matches ids to parameters by position.
    for i, name in enumerate(current_state_group['param_names']):
        loaded_state_group['params'][i] = loaded_name_to_id_mapping[name]
    return state_dict
```

In this code, the loaded `state_dict` ids are adapted to match the order of the current optimizer's `state_dict`. Both the previous and the current optimizers must be initialized with `named_parameters()` so that the `param_names` key is present in the dict.

### Note
This is my first contribution to PyTorch, and I wish to receive feedback or suggestions for improvement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134107
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
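To show how the pieces fit together, here is a minimal end-to-end sketch reusing `adapt_state_dict_ids` from the Solution Example above. The two-layer model, the reversed registration order, and all variable names are illustrative assumptions, not part of the PR:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Illustrative model; parameter names follow the module tree
# ('0.weight', '0.bias', '1.weight', '1.bias').
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# An optimizer saved by a previous run, here with the same parameters
# registered in reverse order, and with momentum state populated.
prev_opt = optim.SGD(reversed(list(model.named_parameters())), lr=0.01, momentum=0.9)
model(torch.randn(3, 4)).sum().backward()
prev_opt.step()
checkpoint = prev_opt.state_dict()

# The current optimizer registers the parameters in their natural order,
# so a plain load would match the momentum buffers by position, not by name.
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
print(optimizer.state_dict()['param_groups'][0]['param_names'])
# ['0.weight', '0.bias', '1.weight', '1.bias']

# Register the pre-hook, then load: the hook realigns the loaded ids by
# name before the state is matched to parameters.
optimizer.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer.load_state_dict(checkpoint)
```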