Skip superfluous storage allocations while constructing meta tensors (#65331)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65331

ghstack-source-id: 148862595

This is a performance optimization for the use case:

```
tensor = torch.tensor(<large_data>, device='meta')
```

where the current implementation performs a superfluous CPU memory allocation even though the target device is the meta device.
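
As a rough sketch of the intended behavior (illustration only, not part of the change itself), a tensor constructed directly on the meta device should end up carrying only shape and dtype metadata, without the input data being materialized in a CPU buffer:

```python
import torch

# Illustration only: a tensor built on the 'meta' device has no real storage,
# so no CPU buffer should be needed for the (potentially large) input data.
data = [[0.0] * 1024 for _ in range(1024)]

t = torch.tensor(data, device="meta")

assert t.device.type == "meta"
assert t.shape == (1024, 1024)
assert t.dtype == torch.float32  # default dtype inferred from Python floats
```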

Test Plan: Run existing tests since no behavioral change is introduced.

Reviewed By: ezyang

Differential Revision: D31055036

fbshipit-source-id: 04d6c13594a71fc65bf2fbd567ee71833a879851
(cherry picked from commit 489d0a151a)
Can Balioglu 2022-02-11 04:48:27 -08:00 committed by PyTorch MergeBot
parent 4737ae7a16
commit 6942fccf60
2 changed files with 14 additions and 4 deletions

test/test_numpy_interop.py

@@ -8,7 +8,7 @@ from itertools import product
 from torch.testing._internal.common_utils import \
     (TestCase, run_tests)
 from torch.testing._internal.common_device_type import \
-    (instantiate_device_type_tests, onlyCPU, dtypes)
+    (instantiate_device_type_tests, onlyCPU, dtypes, skipMeta)
 from torch.testing._internal.common_dtype import get_all_dtypes
 
 # For testing handling NumPy objects and sending tensors to / accepting
@@ -228,6 +228,7 @@ class TestNumPyInterop(TestCase):
         x.strides = (3,)
         self.assertRaises(ValueError, lambda: torch.from_numpy(x))
 
+    @skipMeta
     def test_from_list_of_ndarray_warning(self, device):
         warning_msg = r"Creating a tensor from a list of numpy.ndarrays is extremely slow"
         with self.assertWarnsOnceRegex(UserWarning, warning_msg):

torch/csrc/utils/tensor_new.cpp

@@ -266,7 +266,10 @@ Tensor internal_new_from_data(
   }
 #endif
 
+  auto device = device_opt.has_value() ? *device_opt : options.device();
+
   auto sizes = compute_sizes(data, scalar_type);
+
   ScalarType inferred_scalar_type = type_inference ? infer_scalar_type(data) : scalar_type;
   // This exists to prevent us from tracing the call to empty(). The actual
   // autograd code doesn't really matter, because requires_grad is always false
@@ -298,15 +301,21 @@ Tensor internal_new_from_data(
       tensor.set_(storage);
     } else {
-      tensor = at::empty(sizes, at::initialTensorOptions().dtype(inferred_scalar_type).pinned_memory(pin_memory));
-      if (c10::multiply_integers(tensor.sizes()) !=0 ) {
+      TensorOptions opts = at::initialTensorOptions().dtype(inferred_scalar_type);
+      // If the device is Meta, take the shortcut. We don't want to allocate an
+      // empty CPU tensor which would break our contract for meta tensors.
+      if (device == at::kMeta) {
+        return at::empty(sizes, opts.device(device));
+      }
+
+      tensor = at::empty(sizes, opts.pinned_memory(pin_memory));
+      if (c10::multiply_integers(tensor.sizes()) != 0) {
         recursive_store(
             (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0,
             inferred_scalar_type, tensor.dtype().itemsize(), data);
       }
     }
   }
 
-  auto device = device_opt.has_value() ? *device_opt : options.device();
   pybind11::gil_scoped_release no_gil;
   maybe_initialize_cuda(device);
   // However, it is VERY important that we trace the to() call here (even