mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Revert "Fix the Problems About Defining Static Variable in Inline Function (#147095)"
This reverts commit 3da14d38bd.
Reverted https://github.com/pytorch/pytorch/pull/147095 on behalf of https://github.com/atalman due to breaks internally ([comment](https://github.com/pytorch/pytorch/pull/147095#issuecomment-2787129770))
This commit is contained in:
parent
3e0038ae85
commit
4926bd6004
|
|
@ -4,6 +4,7 @@ import _codecs
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
@ -345,22 +346,23 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
||||||
cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("openreg")
|
cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("openreg")
|
||||||
self.assertTrue(cpu_untyped_storage_pinned.is_pinned("openreg"))
|
self.assertTrue(cpu_untyped_storage_pinned.is_pinned("openreg"))
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function"
|
||||||
|
)
|
||||||
def test_open_device_serialization(self):
|
def test_open_device_serialization(self):
|
||||||
self.module.set_custom_device_index(-1)
|
self.module.set_custom_device_index(-1)
|
||||||
storage = torch.UntypedStorage(4, device=torch.device("openreg"))
|
storage = torch.UntypedStorage(4, device=torch.device("openreg"))
|
||||||
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
|
self.assertEqual(torch.serialization.location_tag(storage), "openreg")
|
||||||
|
|
||||||
self.module.set_custom_device_index(0)
|
self.module.set_custom_device_index(0)
|
||||||
storage = torch.UntypedStorage(4, device=torch.device("openreg"))
|
storage = torch.UntypedStorage(4, device=torch.device("openreg"))
|
||||||
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
|
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
|
||||||
|
|
||||||
# TODO(FFFrog): Comment this because openreg.device is missing
|
cpu_storage = torch.empty(4, 4).storage()
|
||||||
# Uncomment this after improving openreg
|
openreg_storage = torch.serialization.default_restore_location(
|
||||||
# cpu_storage = torch.empty(4, 4).storage()
|
cpu_storage, "openreg:0"
|
||||||
# openreg_storage = torch.serialization.default_restore_location(
|
)
|
||||||
# cpu_storage, "openreg:0"
|
self.assertTrue(openreg_storage.is_openreg)
|
||||||
# )
|
|
||||||
# self.assertTrue(openreg_storage.is_openreg)
|
|
||||||
|
|
||||||
# test tensor MetaData serialization
|
# test tensor MetaData serialization
|
||||||
x = torch.empty(4, 4).long()
|
x = torch.empty(4, 4).long()
|
||||||
|
|
@ -369,24 +371,22 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
||||||
self.module.custom_set_backend_meta(y)
|
self.module.custom_set_backend_meta(y)
|
||||||
self.assertTrue(self.module.check_backend_meta(y))
|
self.assertTrue(self.module.check_backend_meta(y))
|
||||||
|
|
||||||
# TODO(FFFrog): Comment this because openreg.device is missing
|
self.module.custom_serialization_registry()
|
||||||
# Uncomment this after improving openreg
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
# self.module.custom_serialization_registry()
|
path = os.path.join(tmpdir, "data.pt")
|
||||||
# with tempfile.TemporaryDirectory() as tmpdir:
|
torch.save(y, path)
|
||||||
# path = os.path.join(tmpdir, "data.pt")
|
z1 = torch.load(path)
|
||||||
# torch.save(y, path)
|
|
||||||
# z1 = torch.load(path)
|
|
||||||
# loads correctly onto the openreg backend device
|
# loads correctly onto the openreg backend device
|
||||||
# self.assertTrue(z1.is_openreg)
|
self.assertTrue(z1.is_openreg)
|
||||||
# loads BackendMeta data correctly
|
# loads BackendMeta data correctly
|
||||||
# self.assertTrue(self.module.check_backend_meta(z1))
|
self.assertTrue(self.module.check_backend_meta(z1))
|
||||||
|
|
||||||
# cross-backend
|
# cross-backend
|
||||||
# z2 = torch.load(path, map_location="cpu")
|
z2 = torch.load(path, map_location="cpu")
|
||||||
# loads correctly onto the cpu backend device
|
# loads correctly onto the cpu backend device
|
||||||
# self.assertFalse(z2.is_openreg)
|
self.assertFalse(z2.is_openreg)
|
||||||
# loads BackendMeta data correctly
|
# loads BackendMeta data correctly
|
||||||
# self.assertFalse(self.module.check_backend_meta(z2))
|
self.assertFalse(self.module.check_backend_meta(z2))
|
||||||
|
|
||||||
def test_open_device_storage_resize(self):
|
def test_open_device_storage_resize(self):
|
||||||
cpu_tensor = torch.randn([8])
|
cpu_tensor = torch.randn([8])
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
#include <torch/csrc/jit/serialization/pickle.h>
|
#include <torch/csrc/jit/serialization/pickle.h>
|
||||||
|
#include <torch/csrc/jit/serialization/pickler.h>
|
||||||
#include <torch/serialize.h>
|
#include <torch/serialize.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
#include <torch/csrc/distributed/rpc/message.h>
|
#include <torch/csrc/distributed/rpc/message.h>
|
||||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||||
#include <torch/csrc/distributed/rpc/types.h>
|
#include <torch/csrc/distributed/rpc/types.h>
|
||||||
|
#include <torch/csrc/jit/serialization/pickler.h>
|
||||||
namespace torch::distributed::rpc {
|
namespace torch::distributed::rpc {
|
||||||
|
|
||||||
class TORCH_API PythonRemoteCall : public RpcCommandBase {
|
class TORCH_API PythonRemoteCall : public RpcCommandBase {
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||||
#include <torch/csrc/distributed/rpc/types.h>
|
#include <torch/csrc/distributed/rpc/types.h>
|
||||||
#include <torch/csrc/jit/runtime/operator.h>
|
#include <torch/csrc/jit/runtime/operator.h>
|
||||||
|
#include <torch/csrc/jit/serialization/pickler.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch::distributed::rpc {
|
namespace torch::distributed::rpc {
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
#include <torch/csrc/distributed/rpc/message.h>
|
#include <torch/csrc/distributed/rpc/message.h>
|
||||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||||
#include <torch/csrc/jit/runtime/operator.h>
|
#include <torch/csrc/jit/runtime/operator.h>
|
||||||
|
#include <torch/csrc/jit/serialization/pickler.h>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
#include <torch/csrc/distributed/rpc/script_call.h>
|
#include <torch/csrc/distributed/rpc/script_call.h>
|
||||||
#include <torch/csrc/distributed/rpc/types.h>
|
#include <torch/csrc/distributed/rpc/types.h>
|
||||||
#include <torch/csrc/jit/runtime/operator.h>
|
#include <torch/csrc/jit/runtime/operator.h>
|
||||||
|
#include <torch/csrc/jit/serialization/pickler.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch::distributed::rpc {
|
namespace torch::distributed::rpc {
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
#include <torch/csrc/distributed/rpc/message.h>
|
#include <torch/csrc/distributed/rpc/message.h>
|
||||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||||
|
#include <torch/csrc/jit/serialization/pickler.h>
|
||||||
|
|
||||||
namespace torch::distributed::rpc {
|
namespace torch::distributed::rpc {
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@
|
||||||
#include <torch/csrc/jit/serialization/import_export_functions.h>
|
#include <torch/csrc/jit/serialization/import_export_functions.h>
|
||||||
#include <torch/csrc/jit/serialization/import_export_helpers.h>
|
#include <torch/csrc/jit/serialization/import_export_helpers.h>
|
||||||
#include <torch/csrc/jit/serialization/onnx.h>
|
#include <torch/csrc/jit/serialization/onnx.h>
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
|
||||||
#include <torch/csrc/onnx/back_compat.h>
|
#include <torch/csrc/onnx/back_compat.h>
|
||||||
#include <torch/csrc/onnx/onnx.h>
|
#include <torch/csrc/onnx/onnx.h>
|
||||||
#include <torch/version.h>
|
#include <torch/version.h>
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
||||||
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||||
|
#include <torch/csrc/jit/serialization/pickler.h>
|
||||||
#include <torch/csrc/jit/serialization/python_print.h>
|
#include <torch/csrc/jit/serialization/python_print.h>
|
||||||
#include <torch/csrc/jit/serialization/storage_context.h>
|
#include <torch/csrc/jit/serialization/storage_context.h>
|
||||||
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
|
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
|
||||||
|
|
|
||||||
|
|
@ -807,24 +807,4 @@ bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
|
|
||||||
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
|
|
||||||
c10::DeviceType::PrivateUse1};
|
|
||||||
return DeviceTypeAllowlist;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::array<
|
|
||||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
|
||||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
|
|
||||||
GetBackendMetaSerialization() {
|
|
||||||
// The array to save function pointer for BackendMeta serialization.
|
|
||||||
// key is the DeviceType, value is std::pair obj.
|
|
||||||
// value.first represent get function and value.seconde represent set function
|
|
||||||
static std::array<
|
|
||||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
|
||||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>
|
|
||||||
BackendMetaSerialization;
|
|
||||||
return BackendMetaSerialization;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace torch::jit
|
} // namespace torch::jit
|
||||||
|
|
|
||||||
|
|
@ -299,14 +299,27 @@ using BackendMetaPtr = std::function<
|
||||||
void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;
|
void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;
|
||||||
|
|
||||||
// A allowlist of device type, currently available is PrivateUse1
|
// A allowlist of device type, currently available is PrivateUse1
|
||||||
TORCH_API std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist();
|
inline std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
|
||||||
|
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
|
||||||
|
c10::DeviceType::PrivateUse1};
|
||||||
|
return DeviceTypeAllowlist;
|
||||||
|
}
|
||||||
|
|
||||||
// Dynamically obtain serialization function pairs
|
// Dynamically obtain serialization function pairs
|
||||||
// that require the corresponding backend.
|
// that require the corresponding backend.
|
||||||
TORCH_API std::array<
|
inline std::array<
|
||||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
|
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
|
||||||
GetBackendMetaSerialization();
|
GetBackendMetaSerialization() {
|
||||||
|
// The array to save function pointer for BackendMeta serialization.
|
||||||
|
// key is the DeviceType, value is std::pair obj.
|
||||||
|
// value.first represent get function and value.seconde represent set function
|
||||||
|
static std::array<
|
||||||
|
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||||
|
at::COMPILE_TIME_MAX_DEVICE_TYPES>
|
||||||
|
BackendMetaSerialization;
|
||||||
|
return BackendMetaSerialization;
|
||||||
|
}
|
||||||
|
|
||||||
// Register function pointer of Tensor BackendMetadata for serialization.
|
// Register function pointer of Tensor BackendMetadata for serialization.
|
||||||
TORCH_API inline void TensorBackendMetaRegistry(
|
TORCH_API inline void TensorBackendMetaRegistry(
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
#endif
|
#endif
|
||||||
#include <torch/csrc/jit/api/function_impl.h>
|
#include <torch/csrc/jit/api/function_impl.h>
|
||||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||||
|
#include <torch/csrc/jit/serialization/pickler.h>
|
||||||
#include <torch/csrc/jit/serialization/storage_context.h>
|
#include <torch/csrc/jit/serialization/storage_context.h>
|
||||||
#include <torch/csrc/jit/serialization/unpickler.h>
|
#include <torch/csrc/jit/serialization/unpickler.h>
|
||||||
#include <torch/csrc/utils/byte_order.h>
|
#include <torch/csrc/utils/byte_order.h>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user