[FSDP1] print fqns when debug FlatParamHandle (#151336)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151336
Approved by: https://github.com/awgu, https://github.com/Skylion007
This commit is contained in:
Wei Feng 2025-04-15 11:46:11 -07:00 committed by PyTorch MergeBot
parent 2a58d2a155
commit 2102b3b4c5

View File

@ -593,6 +593,9 @@ class FlatParamHandle:
)
self._use_unsharded_views(as_params=False)
def __repr__(self):
return f"FlatParamHandle(flat_param.fqns={self.flat_param._fqns})"
def _init_setattr_fns(self):
use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]