[functional] Avoid duplicate custom get_device call in constructor (#162889)

Trying to reduce the number of `__torch_dispatch__` calls made by FakeTensorMode in the AOT metadata collection pass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162889
Approved by: https://github.com/Lucaskabela, https://github.com/zou3519
Commit 9009c4da39 (parent b68a5115a4)
Author: Animesh Jain, 2025-09-15 12:00:45 -07:00; committed by PyTorch MergeBot


@@ -102,7 +102,7 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
// SparseTensorImpl has no storage, so we cannot query its nbytes.
// (original_storage_size is only used for storage resizing in fsdp anyway, which does not apply to sparse)
// Same for XLA
-  if (base.unsafeGetTensorImpl()->has_storage() && base.device().type() != c10::DeviceType::XLA) {
+  if (base.unsafeGetTensorImpl()->has_storage() && data_ptr().device().type() != c10::DeviceType::XLA) {
original_storage_size_ = base.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl()->sym_nbytes();
} else {
original_storage_size_ = -1;
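
For context, here is a minimal sketch (not part of this PR) of why `base.device()` is costly in this constructor: on a FakeTensor the C++ `device()` query is a "custom" call that gets forwarded to Python as `torch.ops.prim.device` and lands in `__torch_dispatch__`, whereas the storage's `DataPtr` already carries the device, which is presumably why `data_ptr().device()` is a cheap replacement. The subclass name `DeviceCountingTensor` and the counter are hypothetical; the sketch only mimics FakeTensor's behaviour via `dispatch_device=True` to show that each device query is an extra dispatch.

```python
import torch
from torch.utils._pytree import tree_map

class DeviceCountingTensor(torch.Tensor):
    dispatch_calls = 0  # counts __torch_dispatch__ entries

    @staticmethod
    def __new__(cls, elem):
        # dispatch_device=True routes C++-side device() queries through
        # __torch_dispatch__ as torch.ops.prim.device, like FakeTensor does.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device,
            dispatch_device=True)

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        cls.dispatch_calls += 1
        if func is torch.ops.prim.device.default:
            return args[0].elem.device  # answer the device query
        unwrap = lambda t: t.elem if isinstance(t, DeviceCountingTensor) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

t = DeviceCountingTensor(torch.randn(2))
_ = t.device  # each device query is one extra __torch_dispatch__ call
print(DeviceCountingTensor.dispatch_calls)  # expected: 1
```

Reading the device back from the storage's `DataPtr` instead of calling `base.device()` again should leave the constructor with one such dispatch per `FunctionalStorageImpl` rather than two when tracing under FakeTensorMode.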