pytorch/torch/distributed/fsdp
ankurneog e248c1d7eb Update real device in FSDP state_dict_utils (#134994)
## Motivation
The device reported by `tensor.device` defaults to cuda for both sharded and non-sharded tensors. As a result, running the FSDP unit tests on a build without CUDA produces the errors below. This change derives the actual device type from the tensor that was created.

```
[rank3]   File "/root/repos/pytorch-training-tests/tests/pytorch/v2.4.0/distributed_hpu/fsdp/test_fsdp_dtensor_state_dict.py", line 143, in test_dtensor_sharded_tensor_state_dict_identical
[rank3]     sharded_tensor_sd = ref_model.state_dict()
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1944, in state_dict
[rank3]     hook_result = hook(self, destination, prefix, local_metadata)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank3]     return func(*args, **kwargs)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_state_dict_utils.py", line 752, in _post_state_dict_hook
[rank3]     tensor.device,
[rank3]   File "/usr/local/lib/python3.10/dist-packages/typing_extensions.py", line 2853, in wrapper
[rank3]     return arg(*args, **kwargs)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1152, in __torch_function__
[rank3]     return dispatch(st_instance, func)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1134, in dispatch
[rank3]     return _SHARDED_OPS[func](types, args, kwargs, st._process_group)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/op_registry_utils.py", line 33, in wrapper
[rank3]     return wrapped_func(types, args, kwargs, process_group)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py", line 52, in tensor_device
[rank3]     dev = torch.device(torch.cuda.current_device())
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 878, in current_device
[rank3]     _lazy_init()
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 305, in _lazy_init
[rank3]     raise AssertionError("Torch not compiled with CUDA enabled")
[rank3] AssertionError: Torch not compiled with CUDA enabled
```

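The fix follows the direction stated in the motivation: derive the device from the tensor that was actually created rather than assuming CUDA. A minimal sketch of that idea is shown below; the helper name `_infer_real_device` is illustrative and is not the actual code added to `_state_dict_utils.py`.

```python
import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor


def _infer_real_device(tensor: torch.Tensor) -> torch.device:
    """Illustrative helper: return the device the tensor was actually created on,
    instead of letting ShardedTensor's `.device` dispatch fall back to
    torch.cuda.current_device() (which asserts on non-CUDA builds)."""
    if isinstance(tensor, ShardedTensor):
        local_shards = tensor.local_shards()
        if local_shards:
            # Each local shard wraps the real backing tensor (cpu, cuda, hpu, ...),
            # so its device reflects where the data actually lives.
            return local_shards[0].tensor.device
    # Plain tensors already report their real device.
    return tensor.device
```
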
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134994
Approved by: https://github.com/fegin
2024-09-17 04:39:08 +00:00
__init__.py [BE][Easy] enable UFMT for torch/distributed/{fsdp,optim,rpc}/ (#128869) 2024-06-18 21:49:08 +00:00
_common_utils.py Integrate device agnostic APIs in FSDP library [1/n] (#134337) 2024-08-27 17:31:11 +00:00
_debug_utils.py [BE][Easy] enable UFMT for torch/distributed/{fsdp,optim,rpc}/ (#128869) 2024-06-18 21:49:08 +00:00
_dynamo_utils.py Flip default value for mypy disallow_untyped_defs [6/11] (#127843) 2024-06-08 18:49:29 +00:00
_exec_order_utils.py Flip default value for mypy disallow_untyped_defs [6/11] (#127843) 2024-06-08 18:49:29 +00:00
_flat_param.py [reland][dtensor] move DTensor to public namespace (#134203) 2024-09-08 17:08:40 +00:00
_fsdp_extensions.py [reland][dtensor] move DTensor to public namespace (#134203) 2024-09-08 17:08:40 +00:00
_init_utils.py [BE] Raise when the target model has scalar parameters (#132934) 2024-08-12 18:28:02 +00:00
_limiter_utils.py Integrate device agnostic APIs in FSDP library [1/n] (#134337) 2024-08-27 17:31:11 +00:00
_optim_utils.py [reland][dtensor] move DTensor to public namespace (#134203) 2024-09-08 17:08:40 +00:00
_runtime_utils.py [FSDP1][Easy] Remove Spammy Log Lin in _runtime_utils.py (#129967) 2024-07-02 21:08:57 +00:00
_shard_utils.py [reland][dtensor] move DTensor to public namespace (#134203) 2024-09-08 17:08:40 +00:00
_state_dict_utils.py Update real device in FSDP state_dict_utils (#134994) 2024-09-17 04:39:08 +00:00
_trace_utils.py [BE] typing for decorators - fx/_compatibility (part 1) (#134202) 2024-08-22 17:07:33 +00:00
_traversal_utils.py
_unshard_param_utils.py [BE][Easy] enable UFMT for torch/distributed/{fsdp,optim,rpc}/ (#128869) 2024-06-18 21:49:08 +00:00
_wrap_utils.py [BE][Easy] enable UFMT for torch/distributed/{fsdp,optim,rpc}/ (#128869) 2024-06-18 21:49:08 +00:00
api.py [BE][Easy] enable UFMT for torch/distributed/{fsdp,optim,rpc}/ (#128869) 2024-06-18 21:49:08 +00:00
fully_sharded_data_parallel.py [reland][dtensor] move DTensor to public namespace (#134203) 2024-09-08 17:08:40 +00:00
sharded_grad_scaler.py Use _amp_foreach_non_finite_check_and_unscale_ for CPU grads of ShardedGradScaler (#135232) 2024-09-14 09:53:17 +00:00
wrap.py [BE][Easy] enable UFMT for torch/distributed/{fsdp,optim,rpc}/ (#128869) 2024-06-18 21:49:08 +00:00