[functional] Avoid duplicate custom get_device call in constructor (#162889)
Trying to reduce the number of `__torch_dispatch__` calls into FakeTensorMode in the AOT metadata collection pass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162889
Approved by: https://github.com/Lucaskabela, https://github.com/zou3519
parent b68a5115a4
commit 9009c4da39
```diff
@@ -102,7 +102,7 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
   // SparseTensorImpl has no storage, so we cannot query its nbytes.
   // (original_storage_size is only used for storage resizing in fsdp anyway, which does not apply to sparse)
   // Same for XLA
-  if (base.unsafeGetTensorImpl()->has_storage() && base.device().type() != c10::DeviceType::XLA) {
+  if (base.unsafeGetTensorImpl()->has_storage() && data_ptr().device().type() != c10::DeviceType::XLA) {
     original_storage_size_ = base.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl()->sym_nbytes();
   } else {
     original_storage_size_ = -1;
```
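A plausible reading of the one-line change (inferred from the commit title, not spelled out in the message): `base.device()` routes through the `TensorImpl`, and for a tensor with a custom device implementation (as with fake tensors) that means a callback into Python. The storage's `DataPtr`, already materialized by this constructor, carries the same device, so `data_ptr().device()` answers the question without the extra dispatch.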