diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh
index 25434d10f57..159d04e5e44 100755
--- a/.jenkins/pytorch/macos-test.sh
+++ b/.jenkins/pytorch/macos-test.sh
@@ -4,6 +4,8 @@
 # shellcheck source=./macos-common.sh
 source "$(dirname "${BASH_SOURCE[0]}")/macos-common.sh"
 
+export PYTORCH_TEST_SKIP_NOARCH=1
+
 conda install -y six
 pip install -q hypothesis "librosa>=0.6.2" "numba<=0.49.1" psutil
diff --git a/.jenkins/pytorch/win-test.sh b/.jenkins/pytorch/win-test.sh
index 9ee1d4d6fcf..72ee7f19844 100755
--- a/.jenkins/pytorch/win-test.sh
+++ b/.jenkins/pytorch/win-test.sh
@@ -21,6 +21,7 @@ export TEST_DIR="${PWD}/test"
 export TEST_DIR_WIN=$(cygpath -w "${TEST_DIR}")
 export PYTORCH_FINAL_PACKAGE_DIR="/c/users/circleci/workspace/build-results"
 export PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}")
+export PYTORCH_TEST_SKIP_NOARCH=1
 
 mkdir -p $TMP_DIR/build/torch
diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp
index 397bab00ec5..b1f43cf2e6b 100644
--- a/aten/src/ATen/TensorIterator.cpp
+++ b/aten/src/ATen/TensorIterator.cpp
@@ -940,10 +940,6 @@ void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config
   if (!config.check_mem_overlap_) {
     return;
   }
-  if (is_meta_) {
-    // We don't have pointer addresses, cannot check for overlap!
-    return;
-  }
   for (int i = 0; i < num_outputs_; i++) {
     const auto& output = operands_[i].tensor;
     if (!output.defined()) continue;
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index 2720b03a31e..7d80fc036d2 100644
--- a/aten/src/ATen/native/Copy.cpp
+++ b/aten/src/ATen/native/Copy.cpp
@@ -160,6 +160,16 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
     return self;
   }
 
+  // Copies into meta self are OK and just ignored (similar to inplace)
+  if (self.is_meta()) {
+    // TODO: need to see if there is extra error checking needed
+    return self;
+  }
+
+  if (src.is_meta()) {
+    TORCH_CHECK_NOT_IMPLEMENTED(false, "Cannot copy out of meta tensor; no data!")
+  }
+
   // Re-dispatch copies when either src or self device not implemented here (e.g. XLA).
   // _copy_from has a proper device dispatch setup.
   // This includes:
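
The Copy.cpp hunk above gives meta tensors asymmetric copy semantics: copying into a meta tensor is accepted and silently ignored, while copying out of one is an error because there is no data to read. A minimal sketch of that behavior, assuming a build that already contains these changes (the snippet is not part of the patch itself):

```python
import torch

meta = torch.empty(2, 3, device='meta')   # metadata only, no backing data
cpu = torch.ones(2, 3)

# Copying *into* a meta tensor is accepted and ignored, similar to an
# in-place op on meta: the call just returns `meta`.
meta.copy_(cpu)

# Copying *out of* a meta tensor cannot work; there are no bytes to read.
# TORCH_CHECK_NOT_IMPLEMENTED surfaces in Python as NotImplementedError
# (a RuntimeError subclass), so catching RuntimeError covers both.
try:
    cpu.copy_(meta)
except RuntimeError as e:
    print(e)   # "Cannot copy out of meta tensor; no data!"
```
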
diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp
index d1ca12f1108..e37da717278 100644
--- a/aten/src/ATen/native/Distributions.cpp
+++ b/aten/src/ATen/native/Distributions.cpp
@@ -225,10 +225,21 @@ struct UniformStub {
   }
 };
 
+template<typename RNG>
+struct UniformMeta {
+  // No-op!
+  void operator()(TensorIterator& iter, double from, double to, c10::optional<Generator> gen) {
+  }
+};
+
 Tensor& uniform_(Tensor& self, double from, double to, c10::optional<Generator> gen) {
   return at::native::templates::uniform_impl_<UniformStub, Generator>(self, from, to, gen);
 }
 
+Tensor& uniform_meta_(Tensor& self, double from, double to, c10::optional<Generator> gen) {
+  return at::native::templates::uniform_impl_<UniformMeta, Generator>(self, from, to, gen);
+}
+
 // ==================================================== Normal ========================================================
 
 template<typename RNG>
@@ -242,6 +253,11 @@ Tensor& normal_(Tensor& self, double mean, double std, c10::optional<Generator>
   return at::native::templates::normal_impl_<NormalStub, Generator>(self, mean, std, gen);
 }
 
+Tensor& normal_meta_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
+  TORCH_CHECK(std > 0.0, "normal_ expects std > 0.0, but found std=", std);  // TODO: dedupe
+  return self;
+}
+
 Tensor& normal_out(Tensor& output, const Tensor& mean, double std, c10::optional<Generator> gen) {
   return at::native::templates::normal_out_impl<NormalStub, Generator>(output, mean, std, gen);
 }
@@ -289,6 +305,15 @@ struct RandomFromToStub {
   }
 };
 
+template<typename RNG>
+struct RandomFromToMeta {
+  // No-op!
+  void operator()(TensorIterator& iter, uint64_t range, int64_t from, c10::optional<Generator> gen) {
+  }
+  void operator()(TensorIterator& iter, c10::optional<Generator> gen) {
+  }
+};
+
 Tensor& random_(Tensor& self, int64_t from, optional<int64_t> to, c10::optional<Generator> gen) {
   return at::native::templates::random_from_to_impl<RandomFromToStub, Generator>(self, from, to, gen);
 }
@@ -297,6 +322,19 @@ Tensor& random_(Tensor& self, int64_t to, c10::optional<Generator> gen) {
   return random_(self, 0, to, gen);
 }
 
+Tensor& random_meta_(Tensor& self, c10::optional<Generator> gen) {
+  // No error checking yay
+  return self;
+}
+
+Tensor& random_meta_(Tensor& self, int64_t from, optional<int64_t> to, c10::optional<Generator> gen) {
+  return at::native::templates::random_from_to_impl<RandomFromToMeta, Generator>(self, from, to, gen);
+}
+
+Tensor& random_meta_(Tensor& self, int64_t to, c10::optional<Generator> gen) {
+  return random_meta_(self, 0, to, gen);
+}
+
 // ====================================================================================================================
 
 Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
diff --git a/aten/src/ATen/native/Fill.cpp b/aten/src/ATen/native/Fill.cpp
index af6e3a46acf..91f0534d5fa 100644
--- a/aten/src/ATen/native/Fill.cpp
+++ b/aten/src/ATen/native/Fill.cpp
@@ -41,6 +41,15 @@ Tensor& fill_(Tensor& self, const Tensor& value) {
   return fill_out(self, value.item());
 }
 
+Tensor& fill_meta_(Tensor& self, const Scalar& value) {
+  return self;
+}
+
+Tensor& fill_meta_(Tensor& self, const Tensor& value) {
+  TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions.");
+  return self;
+}
+
 DEFINE_DISPATCH(fill_stub);
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill_diagonal ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -114,5 +123,9 @@ Tensor& zero_(Tensor &self) {
   return self.fill_(0);
 }
 
+Tensor& zero_meta_(Tensor& self) {
+  return self;
+}
+
 } // namespace native
 } // namespace at
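
The meta variants added in Distributions.cpp and Fill.cpp do no elementwise work, but they keep some of the argument validation of their CPU/CUDA counterparts. A rough sketch of what that looks like from Python, assuming a build with these kernels registered (exact error types and messages may differ):

```python
import torch

x = torch.empty(4, 4, device='meta')

# The meta kernels registered above do no work (there is no data); they
# simply return self with the metadata unchanged.
x.zero_()
x.fill_(1.0)
x.uniform_(0.0, 1.0)
x.random_(0, 10)

# Argument validation that the meta kernels keep is still enforced:
try:
    x.normal_(0.0, -1.0)                     # normal_meta_ checks std > 0.0
except RuntimeError as e:
    print(e)

try:
    x.fill_(torch.empty(2, device='meta'))   # fill_meta_ wants a 0-dim value
except RuntimeError as e:
    print(e)
```
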
diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp
index 63dc41971ea..0e8b947805a 100644
--- a/aten/src/ATen/native/MetaTensor.cpp
+++ b/aten/src/ATen/native/MetaTensor.cpp
@@ -1,81 +1,70 @@
 #include
 #include
+#include
 
 namespace at {
 namespace native {
 
+// The meta allocator ignores whatever allocation is requested and always
+// gives you nullptr
+struct MetaAllocator final : public at::Allocator {
+  MetaAllocator() = default;
+  ~MetaAllocator() override = default;
+  static void deleter(void* const pointer) {
+    TORCH_INTERNAL_ASSERT(!pointer);
+  }
+  DataPtr allocate(const size_t nbytes) const override {
+    return {nullptr, nullptr, &deleter, at::Device(DeviceType::Meta)};
+  }
+  DeleterFnPtr raw_deleter() const override {
+    return deleter;
+  }
+};
+
+static MetaAllocator g_meta_alloc;
+
+at::Allocator* GetMetaAllocator() {
+  return &g_meta_alloc;
+}
+
 Tensor empty_meta(
   IntArrayRef size,
-  c10::optional<ScalarType> dtype,
-  c10::optional<Layout> layout,
-  c10::optional<Device> device,
-  c10::optional<bool> pin_memory,
-  c10::optional<MemoryFormat> memory_format
+  c10::optional<ScalarType> dtype_opt,
+  c10::optional<Layout> layout_opt,
+  c10::optional<Device> device_opt,
+  c10::optional<bool> pin_memory_opt,
+  c10::optional<MemoryFormat> memory_format_opt
 ) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device_or_default(device).type() == DeviceType::Meta);
+  auto device = device_or_default(device_opt);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::Meta);
   // NB: because there is no SparseMeta (yet), non-strided layout is
   // exerciseable
   TORCH_CHECK_NOT_IMPLEMENTED(
-    layout_or_default(layout) == Layout::Strided,
+    layout_or_default(layout_opt) == Layout::Strided,
     "strided meta tensors not supported yet"
   );
 
-  check_size_nonnegative(size);
-
-  auto tensor = detail::make_tensor<TensorImpl>(
-    DispatchKeySet{DispatchKey::Meta},
-    scalarTypeToTypeMeta(dtype_or_default(dtype)),
-    device
-  );
-
-  tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
-
-  auto memory_format_ = memory_format.value_or(MemoryFormat::Contiguous);
-  tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format_);
-
-  tensor.unsafeGetTensorImpl()->set_storage_access_should_throw();
-
-  return tensor;
+  auto* allocator = GetMetaAllocator();
+  auto dtype = dtype_or_default(dtype_opt);
+  auto r = at::detail::empty_generic(size, allocator, at::DispatchKey::Meta, dtype, device, memory_format_opt);
+  return r;
 }
 
 Tensor empty_strided_meta(
   IntArrayRef size,
   IntArrayRef stride,
-  c10::optional<ScalarType> dtype,
-  c10::optional<Layout> layout,
-  c10::optional<Device> device,
-  c10::optional<bool> pin_memory
+  c10::optional<ScalarType> dtype_opt,
+  c10::optional<Layout> layout_opt,
+  c10::optional<Device> device_opt,
+  c10::optional<bool> pin_memory_opt
 ) {
-
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device_or_default(device).type() == DeviceType::Meta);
-  // NB: because there is no SparseMeta (yet), non-strided layout is
-  // exerciseable
-  TORCH_CHECK_NOT_IMPLEMENTED(
-    layout_or_default(layout) == Layout::Strided,
-    "strided meta tensors not supported yet"
-  );
-
-  // NB: pin_memory intentionally ignored; it is a property of storage and
-  // therefore meta does not track it (this is not a forced choice, but it's
-  // the choice we made)
-
-  check_size_nonnegative(size);
-  // TODO: check if strides are negative,
-  // https://github.com/pytorch/pytorch/issues/53391
-  // (bugged here to be consistent with CPU implementation)
-
-  auto tensor = detail::make_tensor<TensorImpl>(
-    DispatchKeySet{DispatchKey::Meta},
-    scalarTypeToTypeMeta(dtype_or_default(dtype)),
-    device
-  );
-
-  tensor.unsafeGetTensorImpl()->set_sizes_and_strides(size, stride);
-
-  tensor.unsafeGetTensorImpl()->set_storage_access_should_throw();
-
-  return tensor;
+  auto t = at::native::empty_meta({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  // Amazingly the CPU implementation will work for us, because most of resize
+  // is generic except the memcpy, but the memcpy will be skipped if the source
+  // storage is nullptr (which it always is, for meta tensors)
+  at::native::resize_impl_cpu_(t.unsafeGetTensorImpl(), size, stride);
+  return t;
 }
 
 } // namespace native
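
With the MetaAllocator in place, empty_meta produces a tensor whose storage carries a null data pointer, and empty_strided_meta reuses the CPU resize path to install arbitrary strides. A hypothetical usage sketch, not from the patch, assuming a build with this file as shown:

```python
import torch

# empty_meta tracks sizes, strides and dtype, but the MetaAllocator above
# always hands back a null data pointer, so no memory is ever allocated.
a = torch.empty(2, 3, device='meta')
print(a.shape, a.stride(), a.dtype)    # torch.Size([2, 3]) (3, 1) torch.float32

# empty_strided_meta builds a 0-sized meta tensor and then reuses the CPU
# resize path to install the requested sizes/strides; the memcpy inside
# resize is skipped because the source storage's data pointer is null.
b = torch.empty_strided((2, 3), (1, 2), device='meta')
print(b.shape, b.stride())             # torch.Size([2, 3]) (1, 2)
```
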
diff --git a/aten/src/ATen/native/README.md b/aten/src/ATen/native/README.md
index 0fc35b5ad37..54c79937903 100644
--- a/aten/src/ATen/native/README.md
+++ b/aten/src/ATen/native/README.md
@@ -274,6 +274,8 @@ dispatch:
     CompositeImplicitAutograd: func
 
 # overload is ignored, but out functions get suffixed with _out in their name
+# (NB: no out functions in PyTorch today actually support autograd, but if they
+# did, you could call them here and autograd would be inferred)
 func: func.out_overload(...) -> ...
   dispatch:
     CompositeImplicitAutograd: func_out
diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp
index 67257697b9a..72818c697b3 100644
--- a/aten/src/ATen/native/Resize.cpp
+++ b/aten/src/ATen/native/Resize.cpp
@@ -102,13 +102,12 @@ Tensor& resize_as_(
 Tensor& resize_(
     Tensor& self,
     IntArrayRef size,
-    c10::optional<MemoryFormat> optional_memory_format,
-    bool resize_storage) {
+    c10::optional<MemoryFormat> optional_memory_format) {
   if (self.has_names()) {
     return resize_named_tensor_(self, size, optional_memory_format);
   }
   auto* self_ = self.unsafeGetTensorImpl();
-  resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt, resize_storage);
+  resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt);
   if (optional_memory_format.has_value()) {
     auto memory_format =
         optional_memory_format.value();
@@ -121,20 +120,5 @@ Tensor& resize_(
   return self;
 }
 
-Tensor& resize_(
-    Tensor& self,
-    IntArrayRef size,
-    c10::optional<MemoryFormat> optional_memory_format) {
-  return resize_(self, size, optional_memory_format, /*resize_storage=*/true);
-}
-
-Tensor& resize_meta_(
-    Tensor& self,
-    IntArrayRef size,
-    c10::optional<MemoryFormat> optional_memory_format) {
-  // meta tensors don't have storage, so don't resize them
-  return resize_(self, size, optional_memory_format, /*resize_storage=*/false);
-}
-
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 30e49e36c3c..28a279fa65d 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -558,7 +558,7 @@
 - func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
   variants: function, method
   dispatch:
-    CPU, CUDA: as_strided_tensorimpl
+    CPU, CUDA, Meta: as_strided_tensorimpl
     QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl
   device_guard: False
@@ -1522,10 +1522,9 @@
   variants: method
   device_guard: False
   dispatch:
-    CPU: resize_
+    CPU, Meta: resize_
    CUDA: resize_cuda_
    QuantizedCPU: quantized_resize_cpu_
-    Meta: resize_meta_
 
 - func: empty_quantized(int[] size, Tensor qtensor) -> Tensor
   variants: function
@@ -1660,11 +1659,13 @@
   variants: function, method
   dispatch:
     CPU, CUDA, QuantizedCPU, QuantizedCUDA: fill_
+    Meta: fill_meta_
 
 - func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
   variants: function, method
   dispatch:
     CPU, CUDA, QuantizedCPU, QuantizedCUDA: fill_
+    Meta: fill_meta_
 
 - func: floor(Tensor self) -> Tensor
   variants: function, method
@@ -3998,6 +3999,7 @@
   variants: method, function
   dispatch:
     CPU, CUDA: zero_
+    Meta: zero_meta_
     SparseCPU, SparseCUDA: zero_sparse_
     MkldnnCPU: mkldnn_zero_
@@ -4699,7 +4701,7 @@
   variants: method
   device_guard: False
   dispatch:
-    CPU, CUDA, QuantizedCPU, QuantizedCUDA: view
+    CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: view
     MkldnnCPU: mkldnn_view
 
 # Warning: If you want to change the name or overload name of this
@@ -5048,21 +5050,25 @@
   variants: method
   dispatch:
     CPU, CUDA: random_
+    Meta: random_meta_
 
 - func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
   variants: method
   dispatch:
     CPU, CUDA: random_
+    Meta: random_meta_
 
 - func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)
   variants: method
   dispatch:
     CPU, CUDA: random_
+    Meta: random_meta_
 
 - func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
   variants: method
   dispatch:
     CPU, CUDA: uniform_
+    Meta: uniform_meta_
 
 - func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
   variants: method
@@ -6235,6 +6241,7 @@
   variants: method
   dispatch:
     CPU, CUDA: normal_
+    Meta: normal_meta_
 
 - func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
@@ -8451,7 +8458,7 @@
   python_module: special
   variants: function
   dispatch:
-    CompositeExplicitAutograd: special_entr
+    CPU, CUDA: special_entr
 
 - func: special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   python_module: special
diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp
index 2a29b967f37..e24de613ee5 100644
--- a/c10/core/DispatchKeySet.cpp
+++ b/c10/core/DispatchKeySet.cpp
@@ -16,6 +16,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
   DispatchKey::PrivateUse2,
   DispatchKey::PrivateUse3,
   DispatchKey::MLC,
+  DispatchKey::Meta,
 });
 
 bool isBackendDispatchKey(DispatchKey t) {
diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h
index d6c0b7f8f4e..96897e57513 100644
--- a/c10/core/StorageImpl.h
+++ b/c10/core/StorageImpl.h
@@ -7,6 +7,29 @@
 
 namespace c10 {
 
+// A storage represents the underlying backing data buffer for a
+// tensor.  This concept was inherited from the original Torch7
+// codebase; we'd kind of like to get rid of the concept
+// (see https://github.com/pytorch/pytorch/issues/14797) but
+// it's hard work and no one has gotten around to doing it.
+//
+// NB: storage is supposed to uniquely own a data pointer; e.g.,
+// two non-null data pointers alias if and only if they are from
+// the same storage.  Technically you can violate this invariant
+// (e.g., you can create a non-owning StorageImpl with at::from_blob)
+// but a lot of things won't work correctly, including:
+//
+// - An ordinary deleter on such a storage is wrong, because normal deleters
+//   assume unique ownership, but if you have two storages at the same data, that
+//   implies there is some sort of shared ownership.  So your deleter would have to
+//   actually be internally doing some sort of refcount thing
+// - Deepcopy in Python side relies on storage equality and not data pointer
+//   equality; so if there are two separate storages pointing to the same data,
+//   the data will actually get duplicated in that case (one data ptr before, two
+//   data ptrs after)
+// - Version counts won't work correctly, because we do all VC tracking at the
+//   level of storages (unless you explicitly disconnect the VC with detach);
+//   mutation because data pointers are the same are totally untracked
 struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
  public:
  struct use_byte_size_t {};
diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h
index 56d50df31f7..aeadb1fd3d8 100644
--- a/c10/core/TensorOptions.h
+++ b/c10/core/TensorOptions.h
@@ -701,6 +701,8 @@ inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
       return DeviceType::XLA;
     case DispatchKey::Vulkan:
       return DeviceType::Vulkan;
+    case DispatchKey::Meta:
+      return DeviceType::Meta;
 
     // stuff that people are actively developing
     case DispatchKey::XPU:
diff --git a/test/test_linalg.py b/test/test_linalg.py
index e6d4f453f00..8d25f07a115 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -1383,6 +1383,7 @@ class TestLinalg(TestCase):
 
     # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that
     # their matrix norm results match
+    @skipMeta  # https://github.com/pytorch/pytorch/issues/54082
     @skipCUDAIfNoMagma
     @dtypes(torch.float, torch.double)
     @precisionOverride({torch.float32: 2e-5})
@@ -1420,6 +1421,7 @@ class TestLinalg(TestCase):
             for ord in ord_settings:
                 run_test_case(input, ord, dim, keepdim)
 
+    @skipMeta  # https://github.com/pytorch/pytorch/issues/53739
    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@@ -1474,6 +1476,7 @@ class TestLinalg(TestCase):
                 actual = torch.linalg.cond(input, p)
                 self.assertEqual(actual, expected)
 
+    @skipMeta  # https://github.com/pytorch/pytorch/issues/53739
    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@@ -3369,6 +3372,7 @@ class TestLinalg(TestCase):
             a_inv = torch.linalg.tensorinv(a, ind=ind)
             self.assertEqual(a_inv.shape, a.shape[ind:] + a.shape[:ind])
 
+    @skipMeta  # See https://github.com/pytorch/pytorch/issues/53739
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
diff --git a/test/test_torch.py b/test/test_torch.py
index 5ebcef914c2..17a0e93f156 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -29,7 +29,7 @@ from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, skipCUDAIfNoMagma, skipCUDAVersionIn,
     onlyCUDA, onlyCPU,
-    dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
+    dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipMeta,
     PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyOnCPUAndCUDA,
     expectedAlertNondeterministic)
 from typing import Dict, List
@@ -5956,6 +5956,7 @@ class TestTorchDeviceType(TestCase):
             for x in xs:
                 _test_helper(x, op, unary=True)
 
+    @skipMeta
     def test_dlpack_conversion(self, device):
         x = torch.randn(1, 2, 3, 4, device=device, dtype=torch.float)
         z = from_dlpack(to_dlpack(x))
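
The new StorageImpl.h comment hinges on deepcopy being driven by storage identity rather than raw data pointers. A small illustration of that point with ordinary CPU tensors (standard PyTorch behavior, not code from this patch):

```python
import copy
import torch

# Two views of one storage: `b` aliases a[1:].
a = torch.zeros(4)
b = a[1:]

# deepcopy memoizes per *storage*, so the copies keep sharing one storage.
a2, b2 = copy.deepcopy([a, b])
a2[1] = 42.0
print(b2[0])   # tensor(42.) -- the view relationship survived the deepcopy

# If the same bytes were instead wrapped in two separate storages (a
# from_blob-style setup), deepcopy would see two storages and duplicate the
# data: one data pointer before, two after, exactly as the comment warns.
```
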
diff --git a/tools/codegen/dest/register_dispatch_key.py b/tools/codegen/dest/register_dispatch_key.py
index 7a5de260e15..5fc1352177f 100644
--- a/tools/codegen/dest/register_dispatch_key.py
+++ b/tools/codegen/dest/register_dispatch_key.py
@@ -248,11 +248,12 @@ if (C10_UNLIKELY(current_device.has_value())) {
 
         if k is SchemaKind.functional:
             if self.dispatch_key == DispatchKey.Meta:
+                # TODO: dedupe this with below
                 return """
 if (strides.empty()) {
     outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta));
 } else {
-    TORCH_INTERNAL_ASSERT(0, "not implemented yet");
+    outputs_[output_idx] = at::empty_strided(sizes, strides, options.device(at::kMeta));
 }
 """
             else:
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 47894f05447..2e2d6e52217 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -57,7 +57,11 @@ class Tensor(torch._C._TensorBase):
         if id(self) in memo:
             return memo[id(self)]
         with torch.no_grad():
-            if self.is_sparse or self.device.type == 'xla' or self.device.type == 'mlc':
+            # TODO: skipping storage copy is wrong for meta, as meta
+            # does accurate alias tracking; however, the code below
+            # doesn't work because of
+            # https://github.com/pytorch/pytorch/issues/47442
+            if self.is_sparse or self.device.type in ['xla', 'mlc', 'meta']:
                 new_tensor = self.clone()
             else:
                 new_storage = self.storage().__deepcopy__(memo)
diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp
index 92e8a93c284..1d6373c008f 100644
--- a/torch/csrc/DynamicTypes.cpp
+++ b/torch/csrc/DynamicTypes.cpp
@@ -60,6 +60,10 @@ at::DeprecatedTypeProperties* get_type(at::Backend backend, at::ScalarType scala
 PyTypeObject* getPyTypeObject(
     const at::Storage& storage,
     const caffe2::TypeMeta dtype) {
+  // TODO: https://github.com/pytorch/pytorch/issues/47442
+  if (storage.device_type() == at::DeviceType::Meta) {
+    TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings for meta storage objects not supported");
+  }
   at::ScalarType scalarType = at::typeMetaToScalarType(dtype);
   auto attype = &at::getDeprecatedTypeProperties(
       at::dispatchKeyToBackend(c10::computeDispatchKey(scalarType, c10::nullopt, storage.device_type())),
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 34492d60dcf..a1580156c50 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -1072,7 +1072,7 @@ size_t Engine::ready_queue_size(const std::shared_ptr<GraphTask>& graph_task, at
 
 // CPU ready queue is per GraphTask, but CUDA device ready queues are shared across all graph tasks
 auto Engine::ready_queue(std::shared_ptr<ReadyQueue> cpu_ready_queue, at::Device device) -> std::shared_ptr<ReadyQueue>{
-  if (device.type() == at::kCPU) {
+  if (device.type() == at::kCPU || device.type() == at::DeviceType::Meta) {
     // return the cpu ready queue passed in
     TORCH_INTERNAL_ASSERT(cpu_ready_queue);
     return cpu_ready_queue;
diff --git a/torch/csrc/utils/tensor_apply.cpp b/torch/csrc/utils/tensor_apply.cpp
index 4b36484944d..00198892d6f 100644
--- a/torch/csrc/utils/tensor_apply.cpp
+++ b/torch/csrc/utils/tensor_apply.cpp
@@ -54,6 +54,9 @@ static void recursive_apply(IntArrayRef sizes, ScalarType scalarType, int64_t di
 }
 
 Tensor & apply_(Tensor & self, PyObject* fn) {
+  if (self.is_meta()) {
+    return self; // Just skip
+  }
   if (!self.device().is_cpu()) {
     throw TypeError("apply_ is only implemented on CPU tensors");
   }
@@ -63,13 +66,16 @@ Tensor & apply_(Tensor & self, PyObject* fn) {
 }
 
 Tensor & map_(Tensor & self, const Tensor & other_, PyObject* fn) {
-  if (!self.device().is_cpu()) {
-    throw TypeError("map_ is only implemented on CPU tensors");
-  }
   if (!other_.options().type_equal(self.options())) {
     throw TypeError("map_: expected %s for 'other' (got %s)",
        self.toString().c_str(), other_.toString().c_str());
   }
+  if (self.is_meta()) {
+    return self; // Just skip
+  }
+  if (!self.device().is_cpu()) {
+    throw TypeError("map_ is only implemented on CPU tensors");
+  }
   Tensor other;
   std::tie(other) = expand_inplace(self, other_, "map_");
   auto scalarType = self.scalar_type();
@@ -78,9 +84,6 @@ Tensor & map_(Tensor & self, const Tensor & other_, PyObject* fn) {
 }
 
 Tensor & map2_(Tensor & self, const Tensor & x_, const Tensor & y_, PyObject* fn) {
-  if (!self.device().is_cpu() || !x_.device().is_cpu() || !y_.device().is_cpu()) {
-    throw TypeError("map2_ is only implemented on CPU tensors");
-  }
   if (!x_.options().type_equal(self.options())) {
     throw TypeError("map2_: expected %s for argument 'x' (got %s)",
        self.toString().c_str(), x_.toString().c_str());
@@ -89,6 +92,12 @@ Tensor & map2_(Tensor & self, const Tensor & x_, const Tensor & y_, PyObject* fn
     throw TypeError("map2_: expected %s for argument 'y' (got %s)",
        self.toString().c_str(), y_.toString().c_str());
   }
+  if (self.is_meta()) {
+    return self; // Just skip
+  }
+  if (!self.device().is_cpu() || !x_.device().is_cpu() || !y_.device().is_cpu()) {
+    throw TypeError("map2_ is only implemented on CPU tensors");
+  }
   Tensor other1, other2;
   std::tie(other1, other2) = expand_inplace(self, x_, y_, "map2_");
   auto scalarType = self.scalar_type();
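
Finally, the tensor_apply.cpp changes make the Python-level apply_/map_/map2_ entry points silent no-ops on meta tensors while keeping the CPU-only restriction for every other device. A short sketch, assuming a build with these changes; the callables here are arbitrary placeholders:

```python
import torch

def add_one(x):    # placeholder callable for apply_
    return x + 1

m = torch.empty(3, device='meta')

# On meta tensors these calls now return self immediately; there are no
# elements to visit.
m.apply_(add_one)
m.map_(torch.empty(3, device='meta'), lambda a, b: a + b)

# On any other non-CPU device the old TypeError is still raised
# ("apply_ is only implemented on CPU tensors").
```
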