From 4926bd60040cb453aad726dc9b155743e149f11c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 8 Apr 2025 17:10:36 +0000 Subject: [PATCH] Revert "Fix the Problems About Defining Static Variable in Inline Function (#147095)" This reverts commit 3da14d38bd396f5bbe8494872d1509efa1a6f048. Reverted https://github.com/pytorch/pytorch/pull/147095 on behalf of https://github.com/atalman due to breaks internally ([comment](https://github.com/pytorch/pytorch/pull/147095#issuecomment-2787129770)) --- ...cpp_extensions_open_device_registration.py | 50 +++++++++---------- torch/csrc/api/src/serialize.cpp | 1 + .../csrc/distributed/rpc/python_remote_call.h | 1 + torch/csrc/distributed/rpc/rref_proto.h | 1 + torch/csrc/distributed/rpc/script_call.h | 1 + .../csrc/distributed/rpc/script_remote_call.h | 1 + torch/csrc/distributed/rpc/script_resp.h | 1 + torch/csrc/jit/serialization/export.cpp | 1 - torch/csrc/jit/serialization/export.h | 1 + torch/csrc/jit/serialization/pickler.cpp | 20 -------- torch/csrc/jit/serialization/pickler.h | 19 +++++-- torch/csrc/jit/serialization/unpickler.cpp | 1 + 12 files changed, 49 insertions(+), 49 deletions(-) diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 21394218c65..5d1f0c34ee2 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -4,6 +4,7 @@ import _codecs import io import os import sys +import tempfile import unittest from typing import Union from unittest.mock import patch @@ -345,22 +346,23 @@ class TestCppExtensionOpenRgistration(common.TestCase): cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("openreg") self.assertTrue(cpu_untyped_storage_pinned.is_pinned("openreg")) + @unittest.skip( + "Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function" + ) def test_open_device_serialization(self): self.module.set_custom_device_index(-1) storage = torch.UntypedStorage(4, device=torch.device("openreg")) - self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") + self.assertEqual(torch.serialization.location_tag(storage), "openreg") self.module.set_custom_device_index(0) storage = torch.UntypedStorage(4, device=torch.device("openreg")) self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") - # TODO(FFFrog): Comment this because openreg.device is missing - # Uncomment this after improving openreg - # cpu_storage = torch.empty(4, 4).storage() - # openreg_storage = torch.serialization.default_restore_location( - # cpu_storage, "openreg:0" - # ) - # self.assertTrue(openreg_storage.is_openreg) + cpu_storage = torch.empty(4, 4).storage() + openreg_storage = torch.serialization.default_restore_location( + cpu_storage, "openreg:0" + ) + self.assertTrue(openreg_storage.is_openreg) # test tensor MetaData serialization x = torch.empty(4, 4).long() @@ -369,24 +371,22 @@ class TestCppExtensionOpenRgistration(common.TestCase): self.module.custom_set_backend_meta(y) self.assertTrue(self.module.check_backend_meta(y)) - # TODO(FFFrog): Comment this because openreg.device is missing - # Uncomment this after improving openreg - # self.module.custom_serialization_registry() - # with tempfile.TemporaryDirectory() as tmpdir: - # path = os.path.join(tmpdir, "data.pt") - # torch.save(y, path) - # z1 = torch.load(path) - # loads correctly onto the openreg backend device - # self.assertTrue(z1.is_openreg) - # loads BackendMeta data correctly - # self.assertTrue(self.module.check_backend_meta(z1)) + self.module.custom_serialization_registry() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.pt") + torch.save(y, path) + z1 = torch.load(path) + # loads correctly onto the openreg backend device + self.assertTrue(z1.is_openreg) + # loads BackendMeta data correctly + self.assertTrue(self.module.check_backend_meta(z1)) - # cross-backend - # z2 = torch.load(path, map_location="cpu") - # loads correctly onto the cpu backend device - # self.assertFalse(z2.is_openreg) - # loads BackendMeta data correctly - # self.assertFalse(self.module.check_backend_meta(z2)) + # cross-backend + z2 = torch.load(path, map_location="cpu") + # loads correctly onto the cpu backend device + self.assertFalse(z2.is_openreg) + # loads BackendMeta data correctly + self.assertFalse(self.module.check_backend_meta(z2)) def test_open_device_storage_resize(self): cpu_tensor = torch.randn([8]) diff --git a/torch/csrc/api/src/serialize.cpp b/torch/csrc/api/src/serialize.cpp index fae54d12484..e8497a7f22b 100644 --- a/torch/csrc/api/src/serialize.cpp +++ b/torch/csrc/api/src/serialize.cpp @@ -1,4 +1,5 @@ #include +#include #include #include diff --git a/torch/csrc/distributed/rpc/python_remote_call.h b/torch/csrc/distributed/rpc/python_remote_call.h index 09d4ba36dc6..0a3054b594d 100644 --- a/torch/csrc/distributed/rpc/python_remote_call.h +++ b/torch/csrc/distributed/rpc/python_remote_call.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace torch::distributed::rpc { class TORCH_API PythonRemoteCall : public RpcCommandBase { diff --git a/torch/csrc/distributed/rpc/rref_proto.h b/torch/csrc/distributed/rpc/rref_proto.h index a1482b46939..e6bffd1870b 100644 --- a/torch/csrc/distributed/rpc/rref_proto.h +++ b/torch/csrc/distributed/rpc/rref_proto.h @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h index 476ee118fe7..19e1871ead8 100644 --- a/torch/csrc/distributed/rpc/script_call.h +++ b/torch/csrc/distributed/rpc/script_call.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/distributed/rpc/script_remote_call.h b/torch/csrc/distributed/rpc/script_remote_call.h index e18edab6482..534ac004459 100644 --- a/torch/csrc/distributed/rpc/script_remote_call.h +++ b/torch/csrc/distributed/rpc/script_remote_call.h @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/script_resp.h b/torch/csrc/distributed/rpc/script_resp.h index 53841e3d705..fd8cd4b845d 100644 --- a/torch/csrc/distributed/rpc/script_resp.h +++ b/torch/csrc/distributed/rpc/script_resp.h @@ -2,6 +2,7 @@ #include #include +#include namespace torch::distributed::rpc { diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index 9c10e94141a..ac20016c7bb 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h index 6f8e69bf0ca..8b2d6d84716 100644 --- a/torch/csrc/jit/serialization/export.h +++ b/torch/csrc/jit/serialization/export.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 8038aa8ca65..6ce524293a7 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -807,24 +807,4 @@ bool checkHasValidSetGetState(const std::shared_ptr& cls) { return true; } -std::unordered_set& GetBackendMetaAllowlist() { - static std::unordered_set DeviceTypeAllowlist{ - c10::DeviceType::PrivateUse1}; - return DeviceTypeAllowlist; -} - -std::array< - std::optional>, - at::COMPILE_TIME_MAX_DEVICE_TYPES>& -GetBackendMetaSerialization() { - // The array to save function pointer for BackendMeta serialization. - // key is the DeviceType, value is std::pair obj. - // value.first represent get function and value.seconde represent set function - static std::array< - std::optional>, - at::COMPILE_TIME_MAX_DEVICE_TYPES> - BackendMetaSerialization; - return BackendMetaSerialization; -} - } // namespace torch::jit diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index 828f2b3b052..8accfa229b8 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -299,14 +299,27 @@ using BackendMetaPtr = std::function< void(const at::Tensor&, std::unordered_map&)>; // A allowlist of device type, currently available is PrivateUse1 -TORCH_API std::unordered_set& GetBackendMetaAllowlist(); +inline std::unordered_set& GetBackendMetaAllowlist() { + static std::unordered_set DeviceTypeAllowlist{ + c10::DeviceType::PrivateUse1}; + return DeviceTypeAllowlist; +} // Dynamically obtain serialization function pairs // that require the corresponding backend. -TORCH_API std::array< +inline std::array< std::optional>, at::COMPILE_TIME_MAX_DEVICE_TYPES>& -GetBackendMetaSerialization(); +GetBackendMetaSerialization() { + // The array to save function pointer for BackendMeta serialization. + // key is the DeviceType, value is std::pair obj. + // value.first represent get function and value.seconde represent set function + static std::array< + std::optional>, + at::COMPILE_TIME_MAX_DEVICE_TYPES> + BackendMetaSerialization; + return BackendMetaSerialization; +} // Register function pointer of Tensor BackendMetadata for serialization. TORCH_API inline void TensorBackendMetaRegistry( diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index cdd58b8cef3..0cbb710f551 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -5,6 +5,7 @@ #endif #include #include +#include #include #include #include