Revert "Fix the Problems About Defining Static Variable in Inline Function (#147095)"

This reverts commit 3da14d38bd.

Reverted https://github.com/pytorch/pytorch/pull/147095 on behalf of https://github.com/atalman due to breaks internally ([comment](https://github.com/pytorch/pytorch/pull/147095#issuecomment-2787129770))
This commit is contained in:
PyTorch MergeBot 2025-04-08 17:10:36 +00:00
parent 3e0038ae85
commit 4926bd6004
12 changed files with 49 additions and 49 deletions

View File

@ -4,6 +4,7 @@ import _codecs
import io import io
import os import os
import sys import sys
import tempfile
import unittest import unittest
from typing import Union from typing import Union
from unittest.mock import patch from unittest.mock import patch
@ -345,22 +346,23 @@ class TestCppExtensionOpenRgistration(common.TestCase):
cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("openreg") cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("openreg")
self.assertTrue(cpu_untyped_storage_pinned.is_pinned("openreg")) self.assertTrue(cpu_untyped_storage_pinned.is_pinned("openreg"))
@unittest.skip(
"Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function"
)
def test_open_device_serialization(self): def test_open_device_serialization(self):
self.module.set_custom_device_index(-1) self.module.set_custom_device_index(-1)
storage = torch.UntypedStorage(4, device=torch.device("openreg")) storage = torch.UntypedStorage(4, device=torch.device("openreg"))
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") self.assertEqual(torch.serialization.location_tag(storage), "openreg")
self.module.set_custom_device_index(0) self.module.set_custom_device_index(0)
storage = torch.UntypedStorage(4, device=torch.device("openreg")) storage = torch.UntypedStorage(4, device=torch.device("openreg"))
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
# TODO(FFFrog): Comment this because openreg.device is missing cpu_storage = torch.empty(4, 4).storage()
# Uncomment this after improving openreg openreg_storage = torch.serialization.default_restore_location(
# cpu_storage = torch.empty(4, 4).storage() cpu_storage, "openreg:0"
# openreg_storage = torch.serialization.default_restore_location( )
# cpu_storage, "openreg:0" self.assertTrue(openreg_storage.is_openreg)
# )
# self.assertTrue(openreg_storage.is_openreg)
# test tensor MetaData serialization # test tensor MetaData serialization
x = torch.empty(4, 4).long() x = torch.empty(4, 4).long()
@ -369,24 +371,22 @@ class TestCppExtensionOpenRgistration(common.TestCase):
self.module.custom_set_backend_meta(y) self.module.custom_set_backend_meta(y)
self.assertTrue(self.module.check_backend_meta(y)) self.assertTrue(self.module.check_backend_meta(y))
# TODO(FFFrog): Comment this because openreg.device is missing self.module.custom_serialization_registry()
# Uncomment this after improving openreg with tempfile.TemporaryDirectory() as tmpdir:
# self.module.custom_serialization_registry() path = os.path.join(tmpdir, "data.pt")
# with tempfile.TemporaryDirectory() as tmpdir: torch.save(y, path)
# path = os.path.join(tmpdir, "data.pt") z1 = torch.load(path)
# torch.save(y, path)
# z1 = torch.load(path)
# loads correctly onto the openreg backend device # loads correctly onto the openreg backend device
# self.assertTrue(z1.is_openreg) self.assertTrue(z1.is_openreg)
# loads BackendMeta data correctly # loads BackendMeta data correctly
# self.assertTrue(self.module.check_backend_meta(z1)) self.assertTrue(self.module.check_backend_meta(z1))
# cross-backend # cross-backend
# z2 = torch.load(path, map_location="cpu") z2 = torch.load(path, map_location="cpu")
# loads correctly onto the cpu backend device # loads correctly onto the cpu backend device
# self.assertFalse(z2.is_openreg) self.assertFalse(z2.is_openreg)
# loads BackendMeta data correctly # loads BackendMeta data correctly
# self.assertFalse(self.module.check_backend_meta(z2)) self.assertFalse(self.module.check_backend_meta(z2))
def test_open_device_storage_resize(self): def test_open_device_storage_resize(self):
cpu_tensor = torch.randn([8]) cpu_tensor = torch.randn([8])

View File

@ -1,4 +1,5 @@
#include <torch/csrc/jit/serialization/pickle.h> #include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/serialize.h> #include <torch/serialize.h>
#include <vector> #include <vector>

View File

@ -3,6 +3,7 @@
#include <torch/csrc/distributed/rpc/message.h> #include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h> #include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/distributed/rpc/types.h> #include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/serialization/pickler.h>
namespace torch::distributed::rpc { namespace torch::distributed::rpc {
class TORCH_API PythonRemoteCall : public RpcCommandBase { class TORCH_API PythonRemoteCall : public RpcCommandBase {

View File

@ -4,6 +4,7 @@
#include <torch/csrc/distributed/rpc/rpc_command_base.h> #include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/distributed/rpc/types.h> #include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/runtime/operator.h> #include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <vector> #include <vector>
namespace torch::distributed::rpc { namespace torch::distributed::rpc {

View File

@ -3,6 +3,7 @@
#include <torch/csrc/distributed/rpc/message.h> #include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h> #include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/jit/runtime/operator.h> #include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <optional> #include <optional>
#include <vector> #include <vector>

View File

@ -3,6 +3,7 @@
#include <torch/csrc/distributed/rpc/script_call.h> #include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/types.h> #include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/runtime/operator.h> #include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <vector> #include <vector>
namespace torch::distributed::rpc { namespace torch::distributed::rpc {

View File

@ -2,6 +2,7 @@
#include <torch/csrc/distributed/rpc/message.h> #include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h> #include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/jit/serialization/pickler.h>
namespace torch::distributed::rpc { namespace torch::distributed::rpc {

View File

@ -16,7 +16,6 @@
#include <torch/csrc/jit/serialization/import_export_functions.h> #include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h> #include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <torch/csrc/jit/serialization/onnx.h> #include <torch/csrc/jit/serialization/onnx.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/onnx/back_compat.h> #include <torch/csrc/onnx/back_compat.h>
#include <torch/csrc/onnx/onnx.h> #include <torch/csrc/onnx/onnx.h>
#include <torch/version.h> #include <torch/version.h>

View File

@ -5,6 +5,7 @@
#include <torch/csrc/jit/ir/ir.h> #include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/serialization/export_bytecode.h> #include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h> #include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/python_print.h> #include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/jit/serialization/storage_context.h> #include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h> #include <torch/csrc/jit/serialization/type_name_uniquer.h>

View File

@ -807,24 +807,4 @@ bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
return true; return true;
} }
std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
c10::DeviceType::PrivateUse1};
return DeviceTypeAllowlist;
}
std::array<
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
GetBackendMetaSerialization() {
// The array to save function pointer for BackendMeta serialization.
// key is the DeviceType, value is std::pair obj.
// value.first represent get function and value.seconde represent set function
static std::array<
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
at::COMPILE_TIME_MAX_DEVICE_TYPES>
BackendMetaSerialization;
return BackendMetaSerialization;
}
} // namespace torch::jit } // namespace torch::jit

View File

@ -299,14 +299,27 @@ using BackendMetaPtr = std::function<
void(const at::Tensor&, std::unordered_map<std::string, bool>&)>; void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;
// A allowlist of device type, currently available is PrivateUse1 // A allowlist of device type, currently available is PrivateUse1
TORCH_API std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist(); inline std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
c10::DeviceType::PrivateUse1};
return DeviceTypeAllowlist;
}
// Dynamically obtain serialization function pairs // Dynamically obtain serialization function pairs
// that require the corresponding backend. // that require the corresponding backend.
TORCH_API std::array< inline std::array<
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>, std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
at::COMPILE_TIME_MAX_DEVICE_TYPES>& at::COMPILE_TIME_MAX_DEVICE_TYPES>&
GetBackendMetaSerialization(); GetBackendMetaSerialization() {
// The array to save function pointer for BackendMeta serialization.
// key is the DeviceType, value is std::pair obj.
// value.first represent get function and value.seconde represent set function
static std::array<
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
at::COMPILE_TIME_MAX_DEVICE_TYPES>
BackendMetaSerialization;
return BackendMetaSerialization;
}
// Register function pointer of Tensor BackendMetadata for serialization. // Register function pointer of Tensor BackendMetadata for serialization.
TORCH_API inline void TensorBackendMetaRegistry( TORCH_API inline void TensorBackendMetaRegistry(

View File

@ -5,6 +5,7 @@
#endif #endif
#include <torch/csrc/jit/api/function_impl.h> #include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/mobile/type_parser.h> #include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/storage_context.h> #include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/unpickler.h> #include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/csrc/utils/byte_order.h> #include <torch/csrc/utils/byte_order.h>