mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[schema_upgrader] add C++ upgrader for json based upgrading (#156761)
Differential Revision: [D77459912](https://our.internmc.facebook.com/intern/diff/D77459912) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156761 Approved by: https://github.com/angelayi
This commit is contained in:
parent
064a7db7fc
commit
aeffb68d34
|
|
@ -895,6 +895,8 @@ libtorch_python_core_sources = [
|
||||||
"torch/csrc/mps/Module.cpp",
|
"torch/csrc/mps/Module.cpp",
|
||||||
"torch/csrc/mtia/Module.cpp",
|
"torch/csrc/mtia/Module.cpp",
|
||||||
"torch/csrc/export/pybind.cpp",
|
"torch/csrc/export/pybind.cpp",
|
||||||
|
"torch/csrc/export/upgrader.cpp",
|
||||||
|
"torch/csrc/export/example_upgraders.cpp",
|
||||||
"torch/csrc/inductor/aoti_package/pybind.cpp",
|
"torch/csrc/inductor/aoti_package/pybind.cpp",
|
||||||
"torch/csrc/inductor/aoti_runner/pybind.cpp",
|
"torch/csrc/inductor/aoti_runner/pybind.cpp",
|
||||||
"torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
|
"torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
|
||||||
|
|
|
||||||
284
test/export/test_upgrader.py
Normal file
284
test/export/test_upgrader.py
Normal file
|
|
@ -0,0 +1,284 @@
|
||||||
|
# Owner(s): ["oncall: export"]
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.testing._internal.common_utils import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpgrader(TestCase):
    """Exercise the JSON schema upgrader through the ``torch._C._export`` bindings.

    The example upgraders are registered before every test and removed again
    afterwards, so each test sees the same registry contents.
    """

    def setUp(self) -> None:
        super().setUp()
        # Register example upgraders dynamically
        torch._C._export.register_example_upgraders()

    def tearDown(self) -> None:
        # Clean up registered upgraders
        torch._C._export.deregister_example_upgraders()
        super().tearDown()

    def _upgrade(self, artifact, target_version=2):
        """Serialize ``artifact``, upgrade it via the Python binding, and parse the result."""
        upgraded_json_str = torch._C._export.upgrade(
            json.dumps(artifact), target_version
        )
        return json.loads(upgraded_json_str)

    def _assert_schema_version(self, upgraded_json, major, minor=0):
        """Assert that ``upgraded_json`` carries the given schema version."""
        self.assertEqual(upgraded_json["schema_version"]["major"], major)
        self.assertEqual(upgraded_json["schema_version"]["minor"], minor)

    def test_nn_module_stack_transformation_from_v0(self):
        """Test that nn_module_stack strings are prepended with 'test_upgrader_' when upgrading from version 0"""

        # Create a mock JSON object that simulates version 0 schema
        # with nn_module_stack as a string that needs to be upgraded
        mock_json = {
            "schema_version": {"major": 0, "minor": 0},
            "graph_module": {
                "graph": {
                    "nodes": [
                        {
                            "target": "aten.add.Tensor",
                            "inputs": [],
                            "outputs": [],
                            "metadata": {
                                "nn_module_stack": "original_stack_info",
                                "other_field": "some_value",
                            },
                        },
                        {
                            "target": "aten.mul.Tensor",
                            "inputs": [],
                            "outputs": [],
                            "metadata": {
                                "nn_module_stack": "another_stack",
                                "stack_trace": "some trace",
                            },
                        },
                    ]
                }
            },
        }

        # Test the upgrader using the Python binding
        upgraded_json = self._upgrade(mock_json)

        # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
        self._assert_schema_version(upgraded_json, 2)

        # Verify nn_module_stack was prepended with "test_upgrader_"
        nodes = upgraded_json["graph_module"]["graph"]["nodes"]

        # Check first node
        first_node_metadata = nodes[0]["metadata"]
        nn_stack = first_node_metadata["nn_module_stack"]
        self.assertIsInstance(nn_stack, str)
        self.assertEqual(nn_stack, "test_upgrader_original_stack_info")
        # Other metadata should be unchanged
        self.assertEqual(first_node_metadata["other_field"], "some_value")

        # Check second node
        second_node_metadata = nodes[1]["metadata"]
        nn_stack2 = second_node_metadata["nn_module_stack"]
        self.assertIsInstance(nn_stack2, str)
        self.assertEqual(nn_stack2, "test_upgrader_another_stack")
        # Other metadata should be unchanged
        self.assertEqual(second_node_metadata["stack_trace"], "some trace")

    def test_nn_module_stack_error_handling_invalid_type(self):
        """Test error handling when nn_module_stack is not a string"""

        # Test case: nn_module_stack is not a string
        mock_json_invalid_type = {
            "schema_version": {"major": 0, "minor": 0},
            "graph_module": {
                "graph": {
                    "nodes": [
                        {
                            "target": "aten.add.Tensor",
                            "inputs": [],
                            "outputs": [],
                            "metadata": {
                                "nn_module_stack": 42  # Invalid: should be string
                            },
                        }
                    ]
                }
            },
        }

        # Serialize outside the context manager so that only the upgrade call
        # itself is allowed to raise the expected error.
        serialized_json = json.dumps(mock_json_invalid_type)
        with self.assertRaisesRegex(
            RuntimeError,
            "Error in upgrader 'version_0_upgrader_registered'",
        ):
            torch._C._export.upgrade(serialized_json, 2)

    def test_nodes_without_metadata_handled_gracefully(self):
        """Test that nodes without metadata or nn_module_stack are handled gracefully"""

        mock_json = {
            "schema_version": {"major": 0, "minor": 0},
            "graph_module": {
                "graph": {
                    "nodes": [
                        {
                            "target": "aten.add.Tensor",
                            "inputs": [],
                            "outputs": [],
                            # No metadata field
                        },
                        {
                            "target": "aten.mul.Tensor",
                            "inputs": [],
                            "outputs": [],
                            "metadata": {
                                "stack_trace": "some trace"
                                # No nn_module_stack field
                            },
                        },
                    ]
                }
            },
        }

        # Should not raise an error
        upgraded_json = self._upgrade(mock_json)

        # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
        self._assert_schema_version(upgraded_json, 2)

        # Verify nodes are unchanged
        nodes = upgraded_json["graph_module"]["graph"]["nodes"]
        self.assertEqual(len(nodes), 2)

        # First node should have no metadata
        self.assertNotIn("metadata", nodes[0])

        # Second node should have unchanged metadata
        self.assertEqual(nodes[1]["metadata"]["stack_trace"], "some trace")
        self.assertNotIn("nn_module_stack", nodes[1]["metadata"])

    def test_field_renaming_chain_from_v0_complete(self):
        """Test complete field renaming chain from v0: old_test_field -> new_test_field -> new_test_field2"""

        mock_json = {
            "schema_version": {"major": 0, "minor": 0},
            "graph_module": {
                "graph": {
                    "inputs": [],
                    "outputs": [],
                    "nodes": [
                        {
                            "target": "aten.add.Tensor",
                            "inputs": [],
                            "outputs": [],
                            "metadata": {"nn_module_stack": "test_stack"},
                        }
                    ],
                    "old_test_field": "original_value",
                    "existing_field": "existing_value",
                }
            },
        }

        # Test the upgrader using the Python binding
        upgraded_json = self._upgrade(mock_json)

        # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
        self._assert_schema_version(upgraded_json, 2)

        # Verify complete field transformation: old_test_field -> new_test_field -> new_test_field2
        graph = upgraded_json["graph_module"]["graph"]
        self.assertIn("new_test_field2", graph)
        self.assertEqual(graph["new_test_field2"], "original_value")
        self.assertNotIn("old_test_field", graph)
        self.assertNotIn("new_test_field", graph)

        # Verify existing fields are preserved
        self.assertEqual(graph["existing_field"], "existing_value")
        self.assertIn("inputs", graph)
        self.assertIn("outputs", graph)
        self.assertIn("nodes", graph)

        # Verify the nn_module_stack was also upgraded by the other upgrader
        nodes = graph["nodes"]
        self.assertEqual(
            nodes[0]["metadata"]["nn_module_stack"], "test_upgrader_test_stack"
        )

    def test_field_renaming_chain_from_v0_missing_field(self):
        """Test that upgraders work gracefully when old_test_field doesn't exist"""

        mock_json = {
            "schema_version": {"major": 0, "minor": 0},
            "graph_module": {
                "graph": {
                    "inputs": [],
                    "outputs": [],
                    "nodes": [],
                    "existing_field": "existing_value",
                }
            },
        }

        # Test the upgrader using the Python binding
        upgraded_json = self._upgrade(mock_json)

        # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
        self._assert_schema_version(upgraded_json, 2)

        # Verify no field transformations occurred since old_test_field didn't exist
        graph = upgraded_json["graph_module"]["graph"]
        self.assertNotIn("new_test_field2", graph)
        self.assertNotIn("new_test_field", graph)
        self.assertNotIn("old_test_field", graph)

        # Verify existing fields are preserved
        self.assertEqual(graph["existing_field"], "existing_value")
        self.assertIn("inputs", graph)
        self.assertIn("outputs", graph)
        self.assertIn("nodes", graph)

    def test_field_renaming_from_v1_partial_chain(self):
        """Test partial upgrade chain starting from v1: new_test_field -> new_test_field2"""

        mock_json = {
            "schema_version": {"major": 1, "minor": 0},
            "graph_module": {
                "graph": {
                    "inputs": [],
                    "outputs": [],
                    "nodes": [],
                    "new_test_field": "test_value",
                    "existing_field": "existing_value",
                }
            },
        }

        # Test the upgrader using the Python binding
        upgraded_json = self._upgrade(mock_json)

        # Verify the schema version was updated (version 1 -> version 2 due to v1 upgrader only)
        self._assert_schema_version(upgraded_json, 2)

        # Verify new_test_field was renamed to new_test_field2
        graph = upgraded_json["graph_module"]["graph"]
        self.assertIn("new_test_field2", graph)
        self.assertEqual(graph["new_test_field2"], "test_value")
        self.assertNotIn("new_test_field", graph)

        # Verify existing fields are preserved
        self.assertEqual(graph["existing_field"], "existing_value")
        self.assertIn("inputs", graph)
        self.assertIn("outputs", graph)
        self.assertIn("nodes", graph)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
||||||
|
run_tests()
|
||||||
89
torch/csrc/export/example_upgraders.cpp
Normal file
89
torch/csrc/export/example_upgraders.cpp
Normal file
|
|
@ -0,0 +1,89 @@
|
||||||
|
#include <torch/csrc/export/example_upgraders.h>
|
||||||
|
#include <torch/csrc/export/upgrader.h>
|
||||||
|
|
||||||
|
namespace torch::_export {
|
||||||
|
|
||||||
|
/// Register test upgraders for the upgrader system.
|
||||||
|
/// The registered upgraders also demonstrate some common upgrade patterns.
|
||||||
|
static bool test_upgraders_registered = false;
|
||||||
|
|
||||||
|
void registerExampleUpgraders() {
|
||||||
|
if (test_upgraders_registered) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
registerUpgrader(
|
||||||
|
0,
|
||||||
|
"graph_module.graph.nodes",
|
||||||
|
[](const nlohmann::json& nodes_array) -> nlohmann::json {
|
||||||
|
nlohmann::json upgraded_nodes = nodes_array;
|
||||||
|
|
||||||
|
// Process each node in the nodes array
|
||||||
|
for (auto& node : upgraded_nodes) {
|
||||||
|
if (node.contains("metadata") && node["metadata"].is_object()) {
|
||||||
|
// Process each metadata key-value pair
|
||||||
|
for (auto& [key, value] : node["metadata"].items()) {
|
||||||
|
if (key == "nn_module_stack") {
|
||||||
|
// Transform nn_module_stack values by prepending prefix
|
||||||
|
if (value.is_string()) {
|
||||||
|
std::string stack_str = value.get<std::string>();
|
||||||
|
value = "test_upgrader_" + stack_str;
|
||||||
|
} else {
|
||||||
|
throwUpgraderError(
|
||||||
|
"version_0_upgrader_registered",
|
||||||
|
0,
|
||||||
|
"nn_module_stack metadata value must be a string, got: " +
|
||||||
|
std::string(value.type_name()),
|
||||||
|
node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Other metadata keys remain unchanged
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return upgraded_nodes;
|
||||||
|
});
|
||||||
|
|
||||||
|
registerUpgrader(
|
||||||
|
0,
|
||||||
|
"graph_module.graph",
|
||||||
|
[](const nlohmann::json& graph_obj) -> nlohmann::json {
|
||||||
|
nlohmann::json upgraded_graph = graph_obj;
|
||||||
|
|
||||||
|
// Rename field if it exists in the graph object
|
||||||
|
if (upgraded_graph.contains("old_test_field")) {
|
||||||
|
upgraded_graph["new_test_field"] = upgraded_graph["old_test_field"];
|
||||||
|
upgraded_graph.erase("old_test_field");
|
||||||
|
}
|
||||||
|
|
||||||
|
return upgraded_graph;
|
||||||
|
});
|
||||||
|
|
||||||
|
registerUpgrader(
|
||||||
|
1,
|
||||||
|
std::vector<std::string>{"graph_module", "graph"},
|
||||||
|
[](const nlohmann::json& graph_obj) -> nlohmann::json {
|
||||||
|
nlohmann::json upgraded_graph = graph_obj;
|
||||||
|
|
||||||
|
// Continue the field renaming chain from version 0
|
||||||
|
if (upgraded_graph.contains("new_test_field")) {
|
||||||
|
upgraded_graph["new_test_field2"] = upgraded_graph["new_test_field"];
|
||||||
|
upgraded_graph.erase("new_test_field");
|
||||||
|
}
|
||||||
|
|
||||||
|
return upgraded_graph;
|
||||||
|
});
|
||||||
|
|
||||||
|
test_upgraders_registered = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Deregister test upgraders for the upgrader system.
|
||||||
|
void deregisterExampleUpgraders() {
  // Clear the flag so that a later registerExampleUpgraders() call registers
  // the upgraders again.
  test_upgraders_registered = false;

  // Remove the same (version, keypath) pairs that registerExampleUpgraders()
  // adds. The return values are ignored, so calling this when nothing is
  // registered is harmless.
  deregisterUpgrader(1, std::vector<std::string>{"graph_module", "graph"});
  deregisterUpgrader(0, "graph_module.graph");
  deregisterUpgrader(0, "graph_module.graph.nodes");
}
|
||||||
|
|
||||||
|
} // namespace torch::_export
|
||||||
15
torch/csrc/export/example_upgraders.h
Normal file
15
torch/csrc/export/example_upgraders.h
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace torch::_export {
|
||||||
|
|
||||||
|
/// Register example upgraders for the upgrader system for testing.
|
||||||
|
/// This function demonstrates common upgrade patterns and is primarily
|
||||||
|
/// used for testing and demonstration purposes.
|
||||||
|
void registerExampleUpgraders();
|
||||||
|
|
||||||
|
/// Deregister example upgraders for the upgrader system for testing.
|
||||||
|
/// This function cleans up the example upgraders that were registered
|
||||||
|
/// by registerExampleUpgraders().
|
||||||
|
void deregisterExampleUpgraders();
|
||||||
|
|
||||||
|
} // namespace torch::_export
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
|
#include <torch/csrc/export/example_upgraders.h>
|
||||||
#include <torch/csrc/export/pt2_archive_constants.h>
|
#include <torch/csrc/export/pt2_archive_constants.h>
|
||||||
#include <torch/csrc/export/pybind.h>
|
#include <torch/csrc/export/pybind.h>
|
||||||
|
#include <torch/csrc/export/upgrader.h>
|
||||||
#include <torch/csrc/utils/generated_serialization_types.h>
|
#include <torch/csrc/utils/generated_serialization_types.h>
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
|
|
@ -15,13 +17,37 @@ void initExportBindings(PyObject* module) {
|
||||||
|
|
||||||
exportModule.def(
|
exportModule.def(
|
||||||
"deserialize_exported_program", [](const std::string& serialized) {
|
"deserialize_exported_program", [](const std::string& serialized) {
|
||||||
return nlohmann::json::parse(serialized).get<ExportedProgram>();
|
auto parsed = nlohmann::json::parse(serialized);
|
||||||
|
|
||||||
|
// Query the current Python schema version as target
|
||||||
|
// TODO: expose schema_version in generated_serialization_types.h and
|
||||||
|
// access it here directly.
|
||||||
|
py::module_ schema_module =
|
||||||
|
py::module_::import("torch._export.serde.schema");
|
||||||
|
py::tuple schema_version_tuple = schema_module.attr("SCHEMA_VERSION");
|
||||||
|
int target_version = schema_version_tuple[0].cast<int>();
|
||||||
|
|
||||||
|
auto upgraded = upgrade(parsed, target_version);
|
||||||
|
return upgraded.get<ExportedProgram>();
|
||||||
});
|
});
|
||||||
|
|
||||||
exportModule.def("serialize_exported_program", [](const ExportedProgram& ep) {
|
exportModule.def("serialize_exported_program", [](const ExportedProgram& ep) {
|
||||||
return nlohmann::json(ep).dump();
|
return nlohmann::json(ep).dump();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
exportModule.def(
|
||||||
|
"upgrade", [](const std::string& serialized_json, int target_version) {
|
||||||
|
auto parsed = nlohmann::json::parse(serialized_json);
|
||||||
|
auto upgraded = upgrade(parsed, target_version);
|
||||||
|
return upgraded.dump();
|
||||||
|
});
|
||||||
|
|
||||||
|
exportModule.def(
|
||||||
|
"register_example_upgraders", []() { registerExampleUpgraders(); });
|
||||||
|
|
||||||
|
exportModule.def(
|
||||||
|
"deregister_example_upgraders", []() { deregisterExampleUpgraders(); });
|
||||||
|
|
||||||
for (const auto& entry : torch::_export::archive_spec::kAllConstants) {
|
for (const auto& entry : torch::_export::archive_spec::kAllConstants) {
|
||||||
pt2ArchiveModule.attr(entry.first) = entry.second;
|
pt2ArchiveModule.attr(entry.first) = entry.second;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
242
torch/csrc/export/upgrader.cpp
Normal file
242
torch/csrc/export/upgrader.cpp
Normal file
|
|
@ -0,0 +1,242 @@
|
||||||
|
#include <torch/csrc/export/upgrader.h>
|
||||||
|
#include <limits>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace torch::_export {
|
||||||
|
|
||||||
|
// Global upgrader registry organized by version.
|
||||||
|
// Using std::multiset to maintain automatic bottom-up ordering where
|
||||||
|
// deeper keypaths are processed before shallower ones.
|
||||||
|
static std::map<int, std::multiset<Upgrader>> upgrader_registry;
|
||||||
|
|
||||||
|
static const std::multiset<Upgrader>& getUpgrader(int current_version) {
|
||||||
|
static const std::multiset<Upgrader> empty_upgraders;
|
||||||
|
auto it = upgrader_registry.find(current_version);
|
||||||
|
if (it != upgrader_registry.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
return empty_upgraders;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a copy of the value found by following `keypath` from `obj`.
///
/// An empty keypath yields a copy of `obj` itself.
///
/// @throws std::runtime_error if any key along the path is missing (which
///         includes the case where an intermediate value is not an object).
static nlohmann::json getFieldByKeypath(
    const nlohmann::json& obj,
    const std::vector<std::string>& keypath) {
  // Walk the path through pointers so that only the final value is copied,
  // rather than copying the whole artifact and every intermediate subtree.
  const nlohmann::json* current = &obj;
  for (const auto& key : keypath) {
    // find() returns end() both for a missing key and for non-object values.
    const auto child = current->find(key);
    if (child == current->end()) {
      throw std::runtime_error("Keypath not found: " + key);
    }
    current = &(*child);
  }
  return *current;
}
|
||||||
|
|
||||||
|
/// Overwrite the existing value found by following `keypath` from `obj`.
///
/// Every key on the path, including the last one, must already exist; this
/// function never creates new fields.
///
/// @throws std::invalid_argument if `keypath` is empty.
/// @throws std::runtime_error if any key along the path is missing.
static void setFieldByKeypath(
    nlohmann::json& obj,
    const std::vector<std::string>& keypath,
    const nlohmann::json& value) {
  // Guard against an empty path: the previous `keypath.size() - 1` loop bound
  // wrapped around for size 0 and indexed past the end of the vector.
  if (keypath.empty()) {
    throw std::invalid_argument("Empty keypath provided");
  }

  nlohmann::json* current = &obj;
  for (const auto& key : keypath) {
    // find() returns end() both for a missing key and for non-object values.
    const auto child = current->find(key);
    if (child == current->end()) {
      throw std::runtime_error("Keypath not found: " + key);
    }
    current = &(*child);
  }
  *current = value;
}
|
||||||
|
|
||||||
|
/// Construct an upgrader for the field at `kp`. Both arguments are taken by
/// value and moved into the corresponding members.
Upgrader::Upgrader(std::vector<std::string> kp, UpgraderFunction func)
    : keypath(std::move(kp)),
      upgrade_func(std::move(func)) {}
|
||||||
|
|
||||||
|
/// Ordering used by the registry's multiset: longer keypaths sort first, and
/// keypaths of equal length are ordered lexicographically. The upgrade
/// function takes no part in the comparison.
bool Upgrader::operator<(const Upgrader& other) const {
  const auto depth = keypath.size();
  const auto other_depth = other.keypath.size();

  if (depth == other_depth) {
    // Same depth: fall back to lexicographic order for determinism.
    return keypath < other.keypath;
  }
  // Deeper paths come first for bottom-up processing.
  return depth > other_depth;
}
|
||||||
|
|
||||||
|
void registerUpgrader(
|
||||||
|
int version,
|
||||||
|
const std::vector<std::string>& keypath,
|
||||||
|
const UpgraderFunction& upgrade_func) {
|
||||||
|
// Check if an upgrader already exists for this version and keypath
|
||||||
|
auto version_it = upgrader_registry.find(version);
|
||||||
|
if (version_it != upgrader_registry.end()) {
|
||||||
|
const auto& upgraders = version_it->second;
|
||||||
|
|
||||||
|
// Search for existing upgrader with the same keypath
|
||||||
|
for (const auto& existing_upgrader : upgraders) {
|
||||||
|
if (existing_upgrader.keypath == keypath) {
|
||||||
|
std::ostringstream error_stream;
|
||||||
|
error_stream << "Upgrader already registered for version " << version
|
||||||
|
<< " and keypath: ";
|
||||||
|
for (size_t i = 0; i < keypath.size(); ++i) {
|
||||||
|
if (i > 0)
|
||||||
|
error_stream << ".";
|
||||||
|
error_stream << keypath[i];
|
||||||
|
}
|
||||||
|
throw std::runtime_error(error_stream.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
upgrader_registry[version].emplace(keypath, upgrade_func);
|
||||||
|
}
|
||||||
|
|
||||||
|
void registerUpgrader(
|
||||||
|
int version,
|
||||||
|
const std::string& dot_keypath,
|
||||||
|
const UpgraderFunction& upgrade_func) {
|
||||||
|
// Convert dot-separated keypath to vector and delegate to main implementation
|
||||||
|
std::vector<std::string> keypath_vector;
|
||||||
|
std::stringstream ss(dot_keypath);
|
||||||
|
std::string component;
|
||||||
|
|
||||||
|
while (std::getline(ss, component, '.')) {
|
||||||
|
if (component.empty()) {
|
||||||
|
throw std::invalid_argument("Empty component in keypath: " + dot_keypath);
|
||||||
|
}
|
||||||
|
keypath_vector.push_back(component);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (keypath_vector.empty()) {
|
||||||
|
throw std::invalid_argument("Empty keypath provided");
|
||||||
|
}
|
||||||
|
|
||||||
|
registerUpgrader(version, keypath_vector, upgrade_func);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool deregisterUpgrader(int version, const std::vector<std::string>& keypath) {
|
||||||
|
auto version_it = upgrader_registry.find(version);
|
||||||
|
if (version_it == upgrader_registry.end()) {
|
||||||
|
return false; // Version not found
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& upgraders = version_it->second;
|
||||||
|
|
||||||
|
// Find the upgrader with matching keypath
|
||||||
|
for (auto it = upgraders.begin(); it != upgraders.end(); ++it) {
|
||||||
|
if (it->keypath == keypath) {
|
||||||
|
upgraders.erase(it);
|
||||||
|
|
||||||
|
// If this was the last upgrader for this version, remove the version
|
||||||
|
// entry
|
||||||
|
if (upgraders.empty()) {
|
||||||
|
upgrader_registry.erase(version_it);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true; // Successfully removed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false; // Upgrader not found
|
||||||
|
}
|
||||||
|
|
||||||
|
bool deregisterUpgrader(int version, const std::string& dot_keypath) {
|
||||||
|
// Convert dot-separated keypath to vector and delegate to main implementation
|
||||||
|
std::vector<std::string> keypath_vector;
|
||||||
|
std::stringstream ss(dot_keypath);
|
||||||
|
std::string component;
|
||||||
|
|
||||||
|
while (std::getline(ss, component, '.')) {
|
||||||
|
if (component.empty()) {
|
||||||
|
throw std::invalid_argument("Empty component in keypath: " + dot_keypath);
|
||||||
|
}
|
||||||
|
keypath_vector.push_back(component);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (keypath_vector.empty()) {
|
||||||
|
throw std::invalid_argument("Empty keypath provided");
|
||||||
|
}
|
||||||
|
|
||||||
|
return deregisterUpgrader(version, keypath_vector);
|
||||||
|
}
|
||||||
|
|
||||||
|
void throwUpgraderError(
|
||||||
|
const std::string& upgrader_name,
|
||||||
|
int from_version,
|
||||||
|
const std::string& error_message,
|
||||||
|
const nlohmann::json& problematic_object) {
|
||||||
|
std::ostringstream error_stream;
|
||||||
|
error_stream << "Error in upgrader '" << upgrader_name << "' "
|
||||||
|
<< "while upgrading from version " << from_version
|
||||||
|
<< " to version " << from_version + 1 << ": " << error_message;
|
||||||
|
|
||||||
|
if (!problematic_object.empty()) {
|
||||||
|
error_stream << "\nProblematic object: " << problematic_object.dump(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
throw std::runtime_error(error_stream.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upgrade `artifact` step by step until its major schema version equals
/// `target_version`, and return the upgraded copy (the input is not modified).
///
/// For each version, all registered upgraders are applied in registry order
/// (deeper keypaths first). When at least one step was applied, the returned
/// artifact has schema_version.major set to the reached version and
/// schema_version.minor reset to 0.
///
/// @throws std::runtime_error if schema_version is missing, if
///         schema_version.major is missing or not an integer, if the artifact
///         is already newer than `target_version`, if a keypath of a
///         registered upgrader is absent from the artifact, or if the target
///         version cannot be reached with the registered upgraders.
nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
  // Validate that the artifact contains required schema version information.
  // The lookups use find() so that a malformed artifact produces a clear
  // error instead of a JSON type error from an implicit conversion.
  const auto schema_version_it = artifact.find("schema_version");
  if (schema_version_it == artifact.end()) {
    throw std::runtime_error("Missing schema_version field in artifact");
  }
  const auto major_it = schema_version_it->find("major");
  if (major_it == schema_version_it->end() ||
      !major_it->is_number_integer()) {
    throw std::runtime_error(
        "Missing or non-integer schema_version.major field in artifact");
  }

  const int original_version = major_it->get<int>();
  int current_version = original_version;

  // Upgraders only move forward; report a newer artifact explicitly rather
  // than blaming missing upgraders.
  if (current_version > target_version) {
    std::ostringstream error_stream;
    error_stream << "Cannot upgrade artifact at schema version "
                 << current_version << " to older target version "
                 << target_version << ": downgrading is not supported.";
    throw std::runtime_error(error_stream.str());
  }

  auto current_artifact = artifact;

  // Iteratively apply upgraders until target version is reached or no more are
  // available
  while (current_version < target_version) {
    // Look up upgraders for the current version
    const auto& upgraders = getUpgrader(current_version);

    if (upgraders.empty()) {
      // No more upgraders available - stop upgrading
      break;
    }

    // Apply all upgraders for this version in bottom-up order
    // (deeper keypaths first to prevent parent/child conflicts)
    for (const auto& upgrader : upgraders) {
      // Extract the field to be upgraded using its keypath
      auto field_to_upgrade =
          getFieldByKeypath(current_artifact, upgrader.keypath);

      // Apply the upgrade transformation
      auto upgraded_field = upgrader.upgrade_func(field_to_upgrade);

      // Update the artifact with the upgraded field
      setFieldByKeypath(current_artifact, upgrader.keypath, upgraded_field);
    }

    // Move to the next version for potential additional upgrades
    current_version++;
  }

  // Update schema version to reflect the final upgraded version
  if (current_version != original_version) {
    current_artifact["schema_version"]["major"] = current_version;
    // Reset minor version to 0 - the correct minor version should be set
    // when converting the json to in memory representation of ExportedProgram
    current_artifact["schema_version"]["minor"] = 0;
  }

  // Validate that we reached the target version
  if (current_version != target_version) {
    std::ostringstream error_stream;
    error_stream
        << "Failed to upgrade to target version " << target_version
        << ". Final version reached: " << current_version
        << ". This may indicate missing upgraders for intermediate versions.";
    throw std::runtime_error(error_stream.str());
  }

  return current_artifact;
}
|
||||||
|
|
||||||
|
} // namespace torch::_export
|
||||||
118
torch/csrc/export/upgrader.h
Normal file
118
torch/csrc/export/upgrader.h
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
#include <functional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace torch::_export {
|
||||||
|
|
||||||
|
/// Function type for upgrading JSON fields during schema version migration.
|
||||||
|
/// Takes a JSON field and returns the upgraded version of that field.
|
||||||
|
using UpgraderFunction = std::function<nlohmann::json(const nlohmann::json&)>;
|
||||||
|
|
||||||
|
/// Structure containing upgrader information for a specific keypath.
|
||||||
|
/// The version is stored as the map key in the registry, so it's not
|
||||||
|
/// duplicated here.
|
||||||
|
struct Upgrader {
  /// Path to the field that should be upgraded (e.g., {"graph_module",
  /// "graph", "nodes"}). The path is resolved from the top-level JSON object,
  /// which is assumed to represent an ExportedProgram.
  std::vector<std::string> keypath;

  /// Function that performs the actual upgrade transformation. It receives
  /// the JSON value found at `keypath` and returns its replacement.
  UpgraderFunction upgrade_func;

  /// Constructor for creating an upgrader with keypath and function
  Upgrader(std::vector<std::string> kp, UpgraderFunction func);

  /// Comparator for maintaining bottom-up ordering in the registry.
  /// Deeper keypaths are processed first to ensure safe upgrade application
  /// without conflicts between parent and child field modifications.
  /// Only `keypath` takes part in the comparison.
  bool operator<(const Upgrader& other) const;
};
|
||||||
|
|
||||||
|
/// Register an upgrader function for a specific schema version and keypath.
|
||||||
|
///
|
||||||
|
/// This function allows registration of custom upgrade logic that will be
|
||||||
|
/// applied when upgrading artifacts from the specified version. Upgraders
|
||||||
|
/// are applied in bottom-up order (deeper keypaths first) to prevent
|
||||||
|
/// conflicts between parent and child field modifications.
|
||||||
|
///
|
||||||
|
/// @param version The schema version this upgrader applies to
|
||||||
|
/// @param keypath The key path to the field that should be upgraded
|
||||||
|
/// @param upgrade_func Function that performs the upgrade transformation
|
||||||
|
void registerUpgrader(
|
||||||
|
int version,
|
||||||
|
const std::vector<std::string>& keypath,
|
||||||
|
const UpgraderFunction& upgrade_func);
|
||||||
|
|
||||||
|
/// Register an upgrader function using dot-separated keypath notation.
|
||||||
|
///
|
||||||
|
/// Convenience overload that accepts dot-separated keypath strings for
|
||||||
|
/// simpler syntax. For example: "graph_module.graph.nodes" instead of
|
||||||
|
/// {"graph_module", "graph", "nodes"}.
|
||||||
|
///
|
||||||
|
/// @param version The schema version this upgrader applies to
|
||||||
|
/// @param dot_keypath Dot-separated keypath string (e.g., "graph.nodes")
|
||||||
|
/// @param upgrade_func Function that performs the upgrade transformation
|
||||||
|
void registerUpgrader(
|
||||||
|
int version,
|
||||||
|
const std::string& dot_keypath,
|
||||||
|
const UpgraderFunction& upgrade_func);
|
||||||
|
|
||||||
|
/// Removes the upgrader registered for a given schema version and keypath.
///
/// This undoes a previous registration of upgrade logic for `version` and
/// `keypath`. It is useful in testing scenarios, where registered upgraders
/// must be cleaned up or upgrader behavior needs to change dynamically.
///
/// @param version The schema version to deregister the upgrader from
/// @param keypath The key path to the field that should be deregistered
/// @return true if an upgrader was found and removed, false otherwise
bool deregisterUpgrader(int version, const std::vector<std::string>& keypath);
|
||||||
|
|
||||||
|
/// Removes an upgrader, with the keypath given in dot-separated notation.
///
/// Convenience overload for simpler call sites: pass
/// "graph_module.graph.nodes" instead of {"graph_module", "graph", "nodes"}.
///
/// @param version The schema version to deregister the upgrader from
/// @param dot_keypath Dot-separated keypath string (e.g., "graph.nodes")
/// @return true if an upgrader was found and removed, false otherwise
bool deregisterUpgrader(int version, const std::string& dot_keypath);
|
||||||
|
|
||||||
|
/// Utility function for throwing consistent upgrader errors.
|
||||||
|
///
|
||||||
|
/// This function formats error messages in a standardized way for upgrader
|
||||||
|
/// failures, including version information and optional problematic object
|
||||||
|
/// details for debugging.
|
||||||
|
///
|
||||||
|
/// @param upgrader_name Name of the upgrader that failed
|
||||||
|
/// @param from_version Source schema version being upgraded from
|
||||||
|
/// @param error_message Descriptive error message
|
||||||
|
/// @param problematic_object Optional JSON object that caused the error
|
||||||
|
/// @throws std::runtime_error Always throws with formatted error message
|
||||||
|
void throwUpgraderError(
|
||||||
|
const std::string& upgrader_name,
|
||||||
|
int from_version,
|
||||||
|
const std::string& error_message,
|
||||||
|
const nlohmann::json& problematic_object = nlohmann::json::object());
|
||||||
|
|
||||||
|
/// Upgrade a JSON artifact to a specific target version with available
|
||||||
|
/// upgraders until a target version is reached.
|
||||||
|
///
|
||||||
|
/// This handles major version upgrade only. For minor version upgrade,
|
||||||
|
/// e.g. adding a new field with default value, it's automatically handled by
|
||||||
|
/// the default constructor in generated_serialization_types.h.
|
||||||
|
///
|
||||||
|
/// @param artifact The JSON artifact to upgrade
|
||||||
|
/// @param target_version The target schema version to upgrade to
|
||||||
|
/// @return The upgraded JSON artifact with updated schema version
|
||||||
|
/// @throws std::runtime_error if artifact is missing schema_version field
|
||||||
|
/// @throws std::runtime_error if final version doesn't match target version
|
||||||
|
nlohmann::json upgrade(const nlohmann::json& artifact, int target_version);
|
||||||
|
|
||||||
|
} // namespace torch::_export
|
||||||
Loading…
Reference in New Issue
Block a user