pytorch/torch/distributed/_shard
Gufan Yin 5d963474aa Replace enforce_dtype with dtype in ShardedTensor.gather (#110561)
Summary:
Sometimes `local_shards` is empty on some ranks while `out.dtype` is float16; with `enforce_dtype=True` this raises an error because `data` defaults to float32.

Callers know best which dtype they want, so let the caller decide directly.

Temporarily keep `enforce_dtype` for backward compatibility; a usage sketch follows.
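
A minimal sketch of the new call pattern, assuming an already-initialized 2-rank GPU process group; the shapes, device strings, and the `demo` wrapper are illustrative and not taken from the PR:

```python
import torch
import torch.distributed as dist
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def demo():
    # Assumes dist.init_process_group(...) was already called on each rank.
    spec = ChunkShardingSpec(
        dim=0,
        placements=["rank:0/cuda:0", "rank:1/cuda:1"],
    )
    # float32 local shards; some sharding layouts can leave a rank with
    # no local shards at all, which is what broke enforce_dtype=True.
    st = sharded_tensor.rand(spec, 4, 4)

    # Only the destination rank needs an output buffer.
    out = None
    if dist.get_rank() == 0:
        out = torch.empty(4, 4, dtype=torch.float16, device="cuda:0")

    # New API: pass the desired dtype explicitly instead of
    # enforce_dtype=True, so a rank with empty local shards never has
    # to infer the element type of the gathered result.
    st.gather(dst=0, out=out, dtype=torch.float16)
```

Passing `dtype` directly sidesteps the inference problem: the gathered dtype is stated by the caller rather than derived from whatever shards happen to be local.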

Test Plan: Ran local and MAST jobs

Reviewed By: uciyc123

Differential Revision: D46886551

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110561
Approved by: https://github.com/wanchaol, https://github.com/malfet
2023-10-05 23:16:23 +00:00
checkpoint
sharded_optim
sharded_tensor    Replace enforce_dtype with dtype in ShardedTensor.gather (#110561)    2023-10-05 23:16:23 +00:00
sharding_plan
sharding_spec    fix: flake8-bugbear code B024 (#107265)    2023-10-04 23:52:52 +00:00
__init__.py
_utils.py
api.py
common_op_utils.py
metadata.py
op_registry_utils.py
sharder.py