From 1f36ce6e4dba4b287fcbd7c52bef736fd7c9deb0 Mon Sep 17 00:00:00 2001
From: Edward Yang
Date: Mon, 29 Mar 2021 08:34:19 -0700
Subject: [PATCH] Restore storage on meta tensors; increase meta coverage (#53973)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53973

There are two parts to this PR; I had to put them together because adding
support for X causes more test code to be exercised, which in turn may require
a fix for Y.

The first part is restoring the concept of storage to meta tensors.
Previously, meta tensors had a nullptr storage (e.g., `meta_tensor.storage()`
raised an error). As I was increasing the coverage of meta tensors, I started
running into test cases (specifically memory overlap tests) that were failing
because not having storage meant I couldn't check for memory overlap. After
some discussion, we decided that it would make sense for meta tensors to model
storage as well (we already model strides, so getting accurate view
information also seems useful). This PR does that by:

* Rewriting all of the factory functions in MetaTensor.cpp to use the generic
  versions (which are very carefully written to not actually poke at the data
  pointer, so everything works out). The key idea is that meta tensors get a
  special allocator, MetaAllocator, which always returns nullptr even if you
  ask for a nonzero number of bytes. resize_ is also made generic; the normal
  variant can be used directly rather than having to instruct it to avoid
  resizing storage.
* Turning on memory overlap checking in TensorIterator even for meta tensors.
* Although meta tensors now have storage, the concept of meta storage is NOT
  exposed to Python land (as that would imply codegenning MetaFloatStorage,
  MetaDoubleStorage, etc. classes). So `x.storage()` still raises an error,
  and I have a kludge in `__deepcopy__` to break storage sharing upon deep
  copy (this is wrong, but no tests exercise this at the moment).

The second part is adding more support for the functions most used in the
test suite.

* Inplace operations have very simple meta functions. I added `fill_`,
  `zero_`, `random_`, `uniform_` and `normal_`. In the case of random, I take
  advantage of pbelevich's templates for defining random kernels, so that I
  can reuse the common scaffolding and then just register a no-op version of
  the stub that would ordinarily do the RNG. (Look, another structured kernels
  tiny variant!)
* `copy_` is now implemented. Copying into a meta tensor is always OK, but
  copying out of a meta tensor raises an error, as we don't know what the
  "correct" data to copy out would be in that case. (See the illustrative
  snippet below.)
* `empty_strided` usage from structured kernels is now implemented (TBH, this
  could have been done as soon as `empty_strided` was added).
* Meta was missing in a few places in TensorOptions/DispatchKey utility
  functions, so I added the missing cases.
* The autograd engine now correctly homes meta tensors with CPU tensors on the
  CPU ready queue (they have a -1 device index, so CUDA queues wouldn't work
  anyway).
* `apply_`, `map_` and `map2_` are special-cased to no-op when self is a meta
  tensor. These count as inplace operations too, but they are implemented a
  little differently.

Getting more meta function support triggers a number of bugs in the test
suite, which I then fix:

- Linear algebra functions sometimes don't report NotImplementedError because
  the error gets swallowed by catch-all try blocks. This is tracked in
  https://github.com/pytorch/pytorch/issues/53739
- dlpack obviously doesn't work with meta tensors, so I just disabled the
  test.
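To make the intended user-visible behavior concrete, here is a small
illustrative snippet (a sketch only, not part of the patch; it assumes the
standard `torch` Python API, and exact exception types and messages may
differ):

```python
import torch

x = torch.empty(2, 3, device='meta')   # sizes/strides/dtype tracked, no data
x.fill_(1.0)                           # inplace ops on meta tensors are no-ops
x.zero_()
x.uniform_()

y = torch.empty(2, 3)
x.copy_(y)                             # copying *into* a meta tensor is OK (ignored)

try:
    y.copy_(x)                         # copying *out of* a meta tensor: no data to copy
except NotImplementedError:
    pass

try:
    x.storage()                        # meta storage is still not exposed to Python
except Exception:                      # exact error type may vary
    pass
```

Signed-off-by: Edward Z.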
Yang Differential Revision: D27036572 Test Plan: Imported from OSS Reviewed By: agolynski, bdhirsh Pulled By: ezyang fbshipit-source-id: 7005ecf4feb92a643c37389fdfbd852dbf00ac78 --- .jenkins/pytorch/macos-test.sh | 2 + .jenkins/pytorch/win-test.sh | 1 + aten/src/ATen/TensorIterator.cpp | 4 - aten/src/ATen/native/Copy.cpp | 10 ++ aten/src/ATen/native/Distributions.cpp | 38 ++++++++ aten/src/ATen/native/Fill.cpp | 13 +++ aten/src/ATen/native/MetaTensor.cpp | 101 +++++++++----------- aten/src/ATen/native/README.md | 2 + aten/src/ATen/native/Resize.cpp | 20 +--- aten/src/ATen/native/native_functions.yaml | 17 +++- c10/core/DispatchKeySet.cpp | 1 + c10/core/StorageImpl.h | 23 +++++ c10/core/TensorOptions.h | 2 + test/test_linalg.py | 4 + test/test_torch.py | 3 +- tools/codegen/dest/register_dispatch_key.py | 3 +- torch/_tensor.py | 6 +- torch/csrc/DynamicTypes.cpp | 4 + torch/csrc/autograd/engine.cpp | 2 +- torch/csrc/utils/tensor_apply.cpp | 21 ++-- 20 files changed, 184 insertions(+), 93 deletions(-) diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh index 25434d10f57..159d04e5e44 100755 --- a/.jenkins/pytorch/macos-test.sh +++ b/.jenkins/pytorch/macos-test.sh @@ -4,6 +4,8 @@ # shellcheck source=./macos-common.sh source "$(dirname "${BASH_SOURCE[0]}")/macos-common.sh" +export PYTORCH_TEST_SKIP_NOARCH=1 + conda install -y six pip install -q hypothesis "librosa>=0.6.2" "numba<=0.49.1" psutil diff --git a/.jenkins/pytorch/win-test.sh b/.jenkins/pytorch/win-test.sh index 9ee1d4d6fcf..72ee7f19844 100755 --- a/.jenkins/pytorch/win-test.sh +++ b/.jenkins/pytorch/win-test.sh @@ -21,6 +21,7 @@ export TEST_DIR="${PWD}/test" export TEST_DIR_WIN=$(cygpath -w "${TEST_DIR}") export PYTORCH_FINAL_PACKAGE_DIR="/c/users/circleci/workspace/build-results" export PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}") +export PYTORCH_TEST_SKIP_NOARCH=1 mkdir -p $TMP_DIR/build/torch diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index 397bab00ec5..b1f43cf2e6b 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -940,10 +940,6 @@ void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config if (!config.check_mem_overlap_) { return; } - if (is_meta_) { - // We don't have pointer addresses, cannot check for overlap! - return; - } for (int i = 0; i < num_outputs_; i++) { const auto& output = operands_[i].tensor; if (!output.defined()) continue; diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 2720b03a31e..7d80fc036d2 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -160,6 +160,16 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) return self; } + // Copies into meta self are OK and just ignored (similar to inplace) + if (self.is_meta()) { + // TODO: need to see if there is extra error checking needed + return self; + } + + if (src.is_meta()) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Cannot copy out of meta tensor; no data!") + } + // Re-dispatch copies when either src or self device not implemented here (e.g. XLA). // _copy_from has a proper device dispatch setup. // This includes: diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index d1ca12f1108..e37da717278 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -225,10 +225,21 @@ struct UniformStub { } }; +template +struct UniformMeta { + // No-op! 
+ void operator()(TensorIterator& iter, double from, double to, c10::optional gen) { + } +}; + Tensor& uniform_(Tensor& self, double from, double to, c10::optional gen) { return at::native::templates::uniform_impl_(self, from, to, gen); } +Tensor& uniform_meta_(Tensor& self, double from, double to, c10::optional gen) { + return at::native::templates::uniform_impl_(self, from, to, gen); +} + // ==================================================== Normal ======================================================== template @@ -242,6 +253,11 @@ Tensor& normal_(Tensor& self, double mean, double std, c10::optional return at::native::templates::normal_impl_(self, mean, std, gen); } +Tensor& normal_meta_(Tensor& self, double mean, double std, c10::optional gen) { + TORCH_CHECK(std > 0.0, "normal_ expects std > 0.0, but found std=", std); // TODO: dedupe + return self; +} + Tensor& normal_out(Tensor& output, const Tensor& mean, double std, c10::optional gen) { return at::native::templates::normal_out_impl(output, mean, std, gen); } @@ -289,6 +305,15 @@ struct RandomFromToStub { } }; +template +struct RandomFromToMeta { + // No-op! + void operator()(TensorIterator& iter, uint64_t range, int64_t from, c10::optional gen) { + } + void operator()(TensorIterator& iter, c10::optional gen) { + } +}; + Tensor& random_(Tensor& self, int64_t from, optional to, c10::optional gen) { return at::native::templates::random_from_to_impl(self, from, to, gen); } @@ -297,6 +322,19 @@ Tensor& random_(Tensor& self, int64_t to, c10::optional gen) { return random_(self, 0, to, gen); } +Tensor& random_meta_(Tensor& self, c10::optional gen) { + // No error checking yay + return self; +} + +Tensor& random_meta_(Tensor& self, int64_t from, optional to, c10::optional gen) { + return at::native::templates::random_from_to_impl(self, from, to, gen); +} + +Tensor& random_meta_(Tensor& self, int64_t to, c10::optional gen) { + return random_meta_(self, 0, to, gen); +} + // ==================================================================================================================== Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) { diff --git a/aten/src/ATen/native/Fill.cpp b/aten/src/ATen/native/Fill.cpp index af6e3a46acf..91f0534d5fa 100644 --- a/aten/src/ATen/native/Fill.cpp +++ b/aten/src/ATen/native/Fill.cpp @@ -41,6 +41,15 @@ Tensor& fill_(Tensor& self, const Tensor& value) { return fill_out(self, value.item()); } +Tensor& fill_meta_(Tensor& self, const Scalar& value) { + return self; +} + +Tensor& fill_meta_(Tensor& self, const Tensor& value) { + TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions."); + return self; +} + DEFINE_DISPATCH(fill_stub); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill_diagonal ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -114,5 +123,9 @@ Tensor& zero_(Tensor &self) { return self.fill_(0); } +Tensor& zero_meta_(Tensor& self) { + return self; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp index 63dc41971ea..0e8b947805a 100644 --- a/aten/src/ATen/native/MetaTensor.cpp +++ b/aten/src/ATen/native/MetaTensor.cpp @@ -1,81 +1,70 @@ #include #include +#include namespace at { namespace native { +// The meta allocator ignores whatever allocation is requested and always +// gives you nullptr +struct MetaAllocator final : public at::Allocator { + MetaAllocator() = default; + ~MetaAllocator() override = default; + static void deleter(void* 
const pointer) { + TORCH_INTERNAL_ASSERT(!pointer); + } + DataPtr allocate(const size_t nbytes) const override { + return {nullptr, nullptr, &deleter, at::Device(DeviceType::Meta)}; + } + DeleterFnPtr raw_deleter() const override { + return deleter; + } +}; + +static MetaAllocator g_meta_alloc; + +at::Allocator* GetMetaAllocator() { + return &g_meta_alloc; +} + Tensor empty_meta( IntArrayRef size, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory, - c10::optional memory_format + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt, + c10::optional memory_format_opt ) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device_or_default(device).type() == DeviceType::Meta); + auto device = device_or_default(device_opt); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::Meta); // NB: because there is no SparseMeta (yet), non-strided layout is // exerciseable TORCH_CHECK_NOT_IMPLEMENTED( - layout_or_default(layout) == Layout::Strided, + layout_or_default(layout_opt) == Layout::Strided, "strided meta tensors not supported yet" ); - check_size_nonnegative(size); - - auto tensor = detail::make_tensor( - DispatchKeySet{DispatchKey::Meta}, - scalarTypeToTypeMeta(dtype_or_default(dtype)), - device - ); - - tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); - - auto memory_format_ = memory_format.value_or(MemoryFormat::Contiguous); - tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format_); - - tensor.unsafeGetTensorImpl()->set_storage_access_should_throw(); - - return tensor; + auto* allocator = GetMetaAllocator(); + auto dtype = dtype_or_default(dtype_opt); + auto r = at::detail::empty_generic(size, allocator, at::DispatchKey::Meta, dtype, device, memory_format_opt); + return r; } Tensor empty_strided_meta( IntArrayRef size, IntArrayRef stride, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt ) { - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device_or_default(device).type() == DeviceType::Meta); - // NB: because there is no SparseMeta (yet), non-strided layout is - // exerciseable - TORCH_CHECK_NOT_IMPLEMENTED( - layout_or_default(layout) == Layout::Strided, - "strided meta tensors not supported yet" - ); - - // NB: pin_memory intentionally ignored; it is a property of storage and - // therefore meta does not track it (this is not a forced choice, but it's - // the choice we made) - - check_size_nonnegative(size); - // TODO: check if strides are negative, - // https://github.com/pytorch/pytorch/issues/53391 - // (bugged here to be consistent with CPU implementation) - - auto tensor = detail::make_tensor( - DispatchKeySet{DispatchKey::Meta}, - scalarTypeToTypeMeta(dtype_or_default(dtype)), - device - ); - - tensor.unsafeGetTensorImpl()->set_sizes_and_strides(size, stride); - - tensor.unsafeGetTensorImpl()->set_storage_access_should_throw(); - - return tensor; + auto t = at::native::empty_meta({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt); + // Amazingly the CPU implementation will work for us, because most of resize + // is generic except the memcpy, but the memcpy will be skipped if the source + // storage is nullptr (which it always is, for meta tensors) + at::native::resize_impl_cpu_(t.unsafeGetTensorImpl(), size, stride); + return t; } } // namespace native diff --git a/aten/src/ATen/native/README.md 
b/aten/src/ATen/native/README.md index 0fc35b5ad37..54c79937903 100644 --- a/aten/src/ATen/native/README.md +++ b/aten/src/ATen/native/README.md @@ -274,6 +274,8 @@ dispatch: CompositeImplicitAutograd: func # overload is ignored, but out functions get suffixed with _out in their name +# (NB: no out functions in PyTorch today actually support autograd, but if they +# did, you could call them here and autograd would be inferred) func: func.out_overload(...) -> ... dispatch: CompositeImplicitAutograd: func_out diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp index 67257697b9a..72818c697b3 100644 --- a/aten/src/ATen/native/Resize.cpp +++ b/aten/src/ATen/native/Resize.cpp @@ -102,13 +102,12 @@ Tensor& resize_as_( Tensor& resize_( Tensor& self, IntArrayRef size, - c10::optional optional_memory_format, - bool resize_storage) { + c10::optional optional_memory_format) { if (self.has_names()) { return resize_named_tensor_(self, size, optional_memory_format); } auto* self_ = self.unsafeGetTensorImpl(); - resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt, resize_storage); + resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt); if (optional_memory_format.has_value()) { auto memory_format = optional_memory_format.value(); @@ -121,20 +120,5 @@ Tensor& resize_( return self; } -Tensor& resize_( - Tensor& self, - IntArrayRef size, - c10::optional optional_memory_format) { - return resize_(self, size, optional_memory_format, /*resize_storage=*/true); -} - -Tensor& resize_meta_( - Tensor& self, - IntArrayRef size, - c10::optional optional_memory_format) { - // meta tensors don't have storage, so don't resize them - return resize_(self, size, optional_memory_format, /*resize_storage=*/false); -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 30e49e36c3c..28a279fa65d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -558,7 +558,7 @@ - func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a) variants: function, method dispatch: - CPU, CUDA: as_strided_tensorimpl + CPU, CUDA, Meta: as_strided_tensorimpl QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl device_guard: False @@ -1522,10 +1522,9 @@ variants: method device_guard: False dispatch: - CPU: resize_ + CPU, Meta: resize_ CUDA: resize_cuda_ QuantizedCPU: quantized_resize_cpu_ - Meta: resize_meta_ - func: empty_quantized(int[] size, Tensor qtensor) -> Tensor variants: function @@ -1660,11 +1659,13 @@ variants: function, method dispatch: CPU, CUDA, QuantizedCPU, QuantizedCUDA: fill_ + Meta: fill_meta_ - func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) variants: function, method dispatch: CPU, CUDA, QuantizedCPU, QuantizedCUDA: fill_ + Meta: fill_meta_ - func: floor(Tensor self) -> Tensor variants: function, method @@ -3998,6 +3999,7 @@ variants: method, function dispatch: CPU, CUDA: zero_ + Meta: zero_meta_ SparseCPU, SparseCUDA: zero_sparse_ MkldnnCPU: mkldnn_zero_ @@ -4699,7 +4701,7 @@ variants: method device_guard: False dispatch: - CPU, CUDA, QuantizedCPU, QuantizedCUDA: view + CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: view MkldnnCPU: mkldnn_view # Warning: If you want to change the name or overload name of this @@ -5048,21 +5050,25 @@ variants: method dispatch: CPU, CUDA: random_ + Meta: random_meta_ - func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) 
variants: method dispatch: CPU, CUDA: random_ + Meta: random_meta_ - func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU, CUDA: random_ + Meta: random_meta_ - func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU, CUDA: uniform_ + Meta: uniform_meta_ - func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) variants: method @@ -6235,6 +6241,7 @@ variants: method dispatch: CPU, CUDA: normal_ + Meta: normal_meta_ - func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) use_c10_dispatcher: hacky_wrapper_for_legacy_signatures @@ -8451,7 +8458,7 @@ python_module: special variants: function dispatch: - CompositeExplicitAutograd: special_entr + CPU, CUDA: special_entr - func: special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) python_module: special diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 2a29b967f37..e24de613ee5 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -16,6 +16,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends | DispatchKey::PrivateUse2, DispatchKey::PrivateUse3, DispatchKey::MLC, + DispatchKey::Meta, }); bool isBackendDispatchKey(DispatchKey t) { diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index d6c0b7f8f4e..96897e57513 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -7,6 +7,29 @@ namespace c10 { +// A storage represents the underlying backing data buffer for a +// tensor. This concept was inherited from the original Torch7 +// codebase; we'd kind of like to get rid of the concept +// (see https://github.com/pytorch/pytorch/issues/14797) but +// it's hard work and no one has gotten around to doing it. +// +// NB: storage is supposed to uniquely own a data pointer; e.g., +// two non-null data pointers alias if and only if they are from +// the same storage. Technically you can violate this invariant +// (e.g., you can create a non-owning StorageImpl with at::from_blob) +// but a lot of things won't work correctly, including: +// +// - An ordinary deleter on such a storage is wrong, because normal deleters +// assume unique ownership, but if you have two storages at the same data, that +// implies there is some sort of shared ownership. 
So your deleter would have to +// actually be internally doing some sort of refcount thing +// - Deepcopy in Python side relies on storage equality and not data pointer +// equality; so if there are two separate storages pointing to the same data, +// the data will actually get duplicated in that case (one data ptr before, two +// data ptrs after) +// - Version counts won't work correctly, because we do all VC tracking at the +// level of storages (unless you explicitly disconnect the VC with detach); +// mutation because data pointers are the same are totally untracked struct C10_API StorageImpl final : public c10::intrusive_ptr_target { public: struct use_byte_size_t {}; diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index 56d50df31f7..aeadb1fd3d8 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -701,6 +701,8 @@ inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) { return DeviceType::XLA; case DispatchKey::Vulkan: return DeviceType::Vulkan; + case DispatchKey::Meta: + return DeviceType::Meta; // stuff that people are actively developing case DispatchKey::XPU: diff --git a/test/test_linalg.py b/test/test_linalg.py index e6d4f453f00..8d25f07a115 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1383,6 +1383,7 @@ class TestLinalg(TestCase): # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that # their matrix norm results match + @skipMeta # https://github.com/pytorch/pytorch/issues/54082 @skipCUDAIfNoMagma @dtypes(torch.float, torch.double) @precisionOverride({torch.float32: 2e-5}) @@ -1420,6 +1421,7 @@ class TestLinalg(TestCase): for ord in ord_settings: run_test_case(input, ord, dim, keepdim) + @skipMeta # https://github.com/pytorch/pytorch/issues/53739 @skipCPUIfNoLapack @skipCUDAIfNoMagma @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) @@ -1474,6 +1476,7 @@ class TestLinalg(TestCase): actual = torch.linalg.cond(input, p) self.assertEqual(actual, expected) + @skipMeta # https://github.com/pytorch/pytorch/issues/53739 @skipCPUIfNoLapack @skipCUDAIfNoMagma @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) @@ -3369,6 +3372,7 @@ class TestLinalg(TestCase): a_inv = torch.linalg.tensorinv(a, ind=ind) self.assertEqual(a_inv.shape, a.shape[ind:] + a.shape[:ind]) + @skipMeta # See https://github.com/pytorch/pytorch/issues/53739 @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) diff --git a/test/test_torch.py b/test/test_torch.py index 5ebcef914c2..17a0e93f156 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -29,7 +29,7 @@ from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, skipCUDAIfNoMagma, skipCUDAVersionIn, onlyCUDA, onlyCPU, - dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, + dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyOnCPUAndCUDA, expectedAlertNondeterministic) from typing import Dict, List @@ -5956,6 +5956,7 @@ class TestTorchDeviceType(TestCase): for x in xs: _test_helper(x, op, unary=True) + @skipMeta def test_dlpack_conversion(self, device): x = torch.randn(1, 2, 3, 4, device=device, dtype=torch.float) z = from_dlpack(to_dlpack(x)) diff --git a/tools/codegen/dest/register_dispatch_key.py b/tools/codegen/dest/register_dispatch_key.py index 7a5de260e15..5fc1352177f 100644 --- a/tools/codegen/dest/register_dispatch_key.py +++ b/tools/codegen/dest/register_dispatch_key.py @@ 
-248,11 +248,12 @@ if (C10_UNLIKELY(current_device.has_value())) { if k is SchemaKind.functional: if self.dispatch_key == DispatchKey.Meta: + # TODO: dedupe this with below return """ if (strides.empty()) { outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta)); } else { - TORCH_INTERNAL_ASSERT(0, "not implemented yet"); + outputs_[output_idx] = at::empty_strided(sizes, strides, options.device(at::kMeta)); } """ else: diff --git a/torch/_tensor.py b/torch/_tensor.py index 47894f05447..2e2d6e52217 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -57,7 +57,11 @@ class Tensor(torch._C._TensorBase): if id(self) in memo: return memo[id(self)] with torch.no_grad(): - if self.is_sparse or self.device.type == 'xla' or self.device.type == 'mlc': + # TODO: skipping storage copy is wrong for meta, as meta + # does accurate alias tracking; however, the code below + # doesn't work because of + # https://github.com/pytorch/pytorch/issues/47442 + if self.is_sparse or self.device.type in ['xla', 'mlc', 'meta']: new_tensor = self.clone() else: new_storage = self.storage().__deepcopy__(memo) diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index 92e8a93c284..1d6373c008f 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -60,6 +60,10 @@ at::DeprecatedTypeProperties* get_type(at::Backend backend, at::ScalarType scala PyTypeObject* getPyTypeObject( const at::Storage& storage, const caffe2::TypeMeta dtype) { + // TODO: https://github.com/pytorch/pytorch/issues/47442 + if (storage.device_type() == at::DeviceType::Meta) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings for meta storage objects not supported"); + } at::ScalarType scalarType = at::typeMetaToScalarType(dtype); auto attype = &at::getDeprecatedTypeProperties( at::dispatchKeyToBackend(c10::computeDispatchKey(scalarType, c10::nullopt, storage.device_type())), diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 34492d60dcf..a1580156c50 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -1072,7 +1072,7 @@ size_t Engine::ready_queue_size(const std::shared_ptr& graph_task, at // CPU ready queue is per GraphTask, but CUDA device ready queues are shared across all graph tasks auto Engine::ready_queue(std::shared_ptr cpu_ready_queue, at::Device device) -> std::shared_ptr{ - if (device.type() == at::kCPU) { + if (device.type() == at::kCPU || device.type() == at::DeviceType::Meta) { // return the cpu ready queue passed in TORCH_INTERNAL_ASSERT(cpu_ready_queue); return cpu_ready_queue; diff --git a/torch/csrc/utils/tensor_apply.cpp b/torch/csrc/utils/tensor_apply.cpp index 4b36484944d..00198892d6f 100644 --- a/torch/csrc/utils/tensor_apply.cpp +++ b/torch/csrc/utils/tensor_apply.cpp @@ -54,6 +54,9 @@ static void recursive_apply(IntArrayRef sizes, ScalarType scalarType, int64_t di } Tensor & apply_(Tensor & self, PyObject* fn) { + if (self.is_meta()) { + return self; // Just skip + } if (!self.device().is_cpu()) { throw TypeError("apply_ is only implemented on CPU tensors"); } @@ -63,13 +66,16 @@ Tensor & apply_(Tensor & self, PyObject* fn) { } Tensor & map_(Tensor & self, const Tensor & other_, PyObject* fn) { - if (!self.device().is_cpu()) { - throw TypeError("map_ is only implemented on CPU tensors"); - } if (!other_.options().type_equal(self.options())) { throw TypeError("map_: expected %s for 'other' (got %s)", self.toString().c_str(), other_.toString().c_str()); } + if (self.is_meta()) { + return self; // Just skip + } + 
if (!self.device().is_cpu()) { + throw TypeError("map_ is only implemented on CPU tensors"); + } Tensor other; std::tie(other) = expand_inplace(self, other_, "map_"); auto scalarType = self.scalar_type(); @@ -78,9 +84,6 @@ Tensor & map_(Tensor & self, const Tensor & other_, PyObject* fn) { } Tensor & map2_(Tensor & self, const Tensor & x_, const Tensor & y_, PyObject* fn) { - if (!self.device().is_cpu() || !x_.device().is_cpu() || !y_.device().is_cpu()) { - throw TypeError("map2_ is only implemented on CPU tensors"); - } if (!x_.options().type_equal(self.options())) { throw TypeError("map2_: expected %s for argument 'x' (got %s)", self.toString().c_str(), x_.toString().c_str()); @@ -89,6 +92,12 @@ Tensor & map2_(Tensor & self, const Tensor & x_, const Tensor & y_, PyObject* fn throw TypeError("map2_: expected %s for argument 'y' (got %s)", self.toString().c_str(), y_.toString().c_str()); } + if (self.is_meta()) { + return self; // Just skip + } + if (!self.device().is_cpu() || !x_.device().is_cpu() || !y_.device().is_cpu()) { + throw TypeError("map2_ is only implemented on CPU tensors"); + } Tensor other1, other2; std::tie(other1, other2) = expand_inplace(self, x_, y_, "map2_"); auto scalarType = self.scalar_type();
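As a quick illustration of the `apply_`/`map_` special casing above (not part
of the patch; a hedged sketch of the expected behavior using the existing
`Tensor.apply_`/`Tensor.map_` Python methods):

```python
import torch

m = torch.empty(4, device='meta')
n = torch.empty(4, device='meta')

m.apply_(lambda v: v + 1)       # silently a no-op on meta tensors
m.map_(n, lambda a, b: a * b)   # also a no-op; dtype compatibility is still checked

cpu = torch.empty(4)
cpu.apply_(lambda v: v + 1)     # unchanged: works, and remains CPU-only
```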