From aeffb68d3466635e4e95c50bffd7dfebaba94da2 Mon Sep 17 00:00:00 2001
From: Yidi Wu
Date: Fri, 27 Jun 2025 21:05:14 -0700
Subject: [PATCH] [schema_upgrader] add C++ upgrader for json based upgrading (#156761)

Differential Revision: [D77459912](https://our.internmc.facebook.com/intern/diff/D77459912)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156761
Approved by: https://github.com/angelayi
---
 build_variables.bzl                     |   2 +
 test/export/test_upgrader.py            | 284 ++++++++++++++++++++++++
 torch/csrc/export/example_upgraders.cpp |  89 ++++++++
 torch/csrc/export/example_upgraders.h   |  15 ++
 torch/csrc/export/pybind.cpp            |  28 ++-
 torch/csrc/export/upgrader.cpp          | 242 ++++++++++++++++++++
 torch/csrc/export/upgrader.h            | 118 ++++++++++
 7 files changed, 777 insertions(+), 1 deletion(-)
 create mode 100644 test/export/test_upgrader.py
 create mode 100644 torch/csrc/export/example_upgraders.cpp
 create mode 100644 torch/csrc/export/example_upgraders.h
 create mode 100644 torch/csrc/export/upgrader.cpp
 create mode 100644 torch/csrc/export/upgrader.h

diff --git a/build_variables.bzl b/build_variables.bzl
index 76f21a6c1ac..d3376322c40 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -895,6 +895,8 @@ libtorch_python_core_sources = [
     "torch/csrc/mps/Module.cpp",
     "torch/csrc/mtia/Module.cpp",
     "torch/csrc/export/pybind.cpp",
+    "torch/csrc/export/upgrader.cpp",
+    "torch/csrc/export/example_upgraders.cpp",
     "torch/csrc/inductor/aoti_package/pybind.cpp",
     "torch/csrc/inductor/aoti_runner/pybind.cpp",
     "torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
diff --git a/test/export/test_upgrader.py b/test/export/test_upgrader.py
new file mode 100644
index 00000000000..0c36b28750f
--- /dev/null
+++ b/test/export/test_upgrader.py
@@ -0,0 +1,284 @@
# Owner(s): ["oncall: export"]

import json

import torch
from torch.testing._internal.common_utils import TestCase


class TestUpgrader(TestCase):
    def setUp(self) -> None:
        # Register example upgraders dynamically
        torch._C._export.register_example_upgraders()

    def tearDown(self) -> None:
        # Clean up registered upgraders
        torch._C._export.deregister_example_upgraders()

    def test_nn_module_stack_transformation_from_v0(self):
        """Test that nn_module_stack strings are prepended with 'test_upgrader_' when upgrading from version 0"""

        # Create a mock JSON object that simulates version 0 schema
        # with nn_module_stack as a string that needs to be upgraded
        mock_json = {
            "schema_version": {"major": 0, "minor": 0},
            "graph_module": {
                "graph": {
                    "nodes": [
                        {
                            "target": "aten.add.Tensor",
                            "inputs": [],
                            "outputs": [],
                            "metadata": {
                                "nn_module_stack": "original_stack_info",
                                "other_field": "some_value",
                            },
                        },
                        {
                            "target": "aten.mul.Tensor",
                            "inputs": [],
                            "outputs": [],
                            "metadata": {
                                "nn_module_stack": "another_stack",
                                "stack_trace": "some trace",
                            },
                        },
                    ]
                }
            },
        }

        # Test the upgrader using the Python binding
        serialized_json = json.dumps(mock_json)
        upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
        upgraded_json = json.loads(upgraded_json_str)

        # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
        self.assertEqual(upgraded_json["schema_version"]["major"], 2)
        self.assertEqual(upgraded_json["schema_version"]["minor"], 0)

        # Verify nn_module_stack was prepended with "test_upgrader_"
        nodes = upgraded_json["graph_module"]["graph"]["nodes"]

        # Check first node
        first_node_metadata = nodes[0]["metadata"]
        nn_stack = \
first_node_metadata["nn_module_stack"] + self.assertIsInstance(nn_stack, str) + self.assertEqual(nn_stack, "test_upgrader_original_stack_info") + # Other metadata should be unchanged + self.assertEqual(first_node_metadata["other_field"], "some_value") + + # Check second node + second_node_metadata = nodes[1]["metadata"] + nn_stack2 = second_node_metadata["nn_module_stack"] + self.assertIsInstance(nn_stack2, str) + self.assertEqual(nn_stack2, "test_upgrader_another_stack") + # Other metadata should be unchanged + self.assertEqual(second_node_metadata["stack_trace"], "some trace") + + def test_nn_module_stack_error_handling_invalid_type(self): + """Test error handling when nn_module_stack is not a string""" + + # Test case: nn_module_stack is not a string + mock_json_invalid_type = { + "schema_version": {"major": 0, "minor": 0}, + "graph_module": { + "graph": { + "nodes": [ + { + "target": "aten.add.Tensor", + "inputs": [], + "outputs": [], + "metadata": { + "nn_module_stack": 42 # Invalid: should be string + }, + } + ] + } + }, + } + + with self.assertRaisesRegex( + RuntimeError, + "Error in upgrader 'version_0_upgrader_registered'", + ): + serialized_json = json.dumps(mock_json_invalid_type) + torch._C._export.upgrade(serialized_json, 2) + + def test_nodes_without_metadata_handled_gracefully(self): + """Test that nodes without metadata or nn_module_stack are handled gracefully""" + + mock_json = { + "schema_version": {"major": 0, "minor": 0}, + "graph_module": { + "graph": { + "nodes": [ + { + "target": "aten.add.Tensor", + "inputs": [], + "outputs": [], + # No metadata field + }, + { + "target": "aten.mul.Tensor", + "inputs": [], + "outputs": [], + "metadata": { + "stack_trace": "some trace" + # No nn_module_stack field + }, + }, + ] + } + }, + } + + # Should not raise an error + serialized_json = json.dumps(mock_json) + upgraded_json_str = torch._C._export.upgrade(serialized_json, 2) + upgraded_json = json.loads(upgraded_json_str) + + # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders) + self.assertEqual(upgraded_json["schema_version"]["major"], 2) + self.assertEqual(upgraded_json["schema_version"]["minor"], 0) + + # Verify nodes are unchanged + nodes = upgraded_json["graph_module"]["graph"]["nodes"] + self.assertEqual(len(nodes), 2) + + # First node should have no metadata + self.assertNotIn("metadata", nodes[0]) + + # Second node should have unchanged metadata + self.assertEqual(nodes[1]["metadata"]["stack_trace"], "some trace") + self.assertNotIn("nn_module_stack", nodes[1]["metadata"]) + + def test_field_renaming_chain_from_v0_complete(self): + """Test complete field renaming chain from v0: old_test_field -> new_test_field -> new_test_field2""" + + mock_json = { + "schema_version": {"major": 0, "minor": 0}, + "graph_module": { + "graph": { + "inputs": [], + "outputs": [], + "nodes": [ + { + "target": "aten.add.Tensor", + "inputs": [], + "outputs": [], + "metadata": {"nn_module_stack": "test_stack"}, + } + ], + "old_test_field": "original_value", + "existing_field": "existing_value", + } + }, + } + + # Test the upgrader using the Python binding + serialized_json = json.dumps(mock_json) + upgraded_json_str = torch._C._export.upgrade(serialized_json, 2) + upgraded_json = json.loads(upgraded_json_str) + + # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders) + self.assertEqual(upgraded_json["schema_version"]["major"], 2) + self.assertEqual(upgraded_json["schema_version"]["minor"], 0) + + # Verify 
complete field transformation: old_test_field -> new_test_field -> new_test_field2
        graph = upgraded_json["graph_module"]["graph"]
        self.assertIn("new_test_field2", graph)
        self.assertEqual(graph["new_test_field2"], "original_value")
        self.assertNotIn("old_test_field", graph)
        self.assertNotIn("new_test_field", graph)

        # Verify existing fields are preserved
        self.assertEqual(graph["existing_field"], "existing_value")
        self.assertIn("inputs", graph)
        self.assertIn("outputs", graph)
        self.assertIn("nodes", graph)

        # Verify the nn_module_stack was also upgraded by the other upgrader
        nodes = graph["nodes"]
        self.assertEqual(
            nodes[0]["metadata"]["nn_module_stack"], "test_upgrader_test_stack"
        )

    def test_field_renaming_chain_from_v0_missing_field(self):
        """Test that upgraders work gracefully when old_test_field doesn't exist"""

        mock_json = {
            "schema_version": {"major": 0, "minor": 0},
            "graph_module": {
                "graph": {
                    "inputs": [],
                    "outputs": [],
                    "nodes": [],
                    "existing_field": "existing_value",
                }
            },
        }

        # Test the upgrader using the Python binding
        serialized_json = json.dumps(mock_json)
        upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
        upgraded_json = json.loads(upgraded_json_str)

        # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
        self.assertEqual(upgraded_json["schema_version"]["major"], 2)
        self.assertEqual(upgraded_json["schema_version"]["minor"], 0)

        # Verify no field transformations occurred since old_test_field didn't exist
        graph = upgraded_json["graph_module"]["graph"]
        self.assertNotIn("new_test_field2", graph)
        self.assertNotIn("new_test_field", graph)
        self.assertNotIn("old_test_field", graph)

        # Verify existing fields are preserved
        self.assertEqual(graph["existing_field"], "existing_value")
        self.assertIn("inputs", graph)
        self.assertIn("outputs", graph)
        self.assertIn("nodes", graph)

    def test_field_renaming_from_v1_partial_chain(self):
        """Test partial upgrade chain starting from v1: new_test_field -> new_test_field2"""

        mock_json = {
            "schema_version": {"major": 1, "minor": 0},
            "graph_module": {
                "graph": {
                    "inputs": [],
                    "outputs": [],
                    "nodes": [],
                    "new_test_field": "test_value",
                    "existing_field": "existing_value",
                }
            },
        }

        # Test the upgrader using the Python binding
        serialized_json = json.dumps(mock_json)
        upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
        upgraded_json = json.loads(upgraded_json_str)

        # Verify the schema version was updated (version 1 -> version 2 due to v1 upgrader only)
        self.assertEqual(upgraded_json["schema_version"]["major"], 2)
        self.assertEqual(upgraded_json["schema_version"]["minor"], 0)

        # Verify new_test_field was renamed to new_test_field2
        graph = upgraded_json["graph_module"]["graph"]
        self.assertIn("new_test_field2", graph)
        self.assertEqual(graph["new_test_field2"], "test_value")
        self.assertNotIn("new_test_field", graph)

        # Verify existing fields are preserved
        self.assertEqual(graph["existing_field"], "existing_value")
        self.assertIn("inputs", graph)
        self.assertIn("outputs", graph)
        self.assertIn("nodes", graph)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
diff --git a/torch/csrc/export/example_upgraders.cpp b/torch/csrc/export/example_upgraders.cpp
new file mode 100644
index 00000000000..398c01301f0
--- /dev/null
+++ b/torch/csrc/export/example_upgraders.cpp
@@ -0,0 +1,89 @@
#include <torch/csrc/export/example_upgraders.h>
#include <torch/csrc/export/upgrader.h>

namespace torch::_export {

/// Registers example upgraders for the upgrader system and demonstrates
/// some common upgrade patterns.
static bool test_upgraders_registered = false;

void registerExampleUpgraders() {
  if (test_upgraders_registered) {
    return;
  }

  registerUpgrader(
      0,
      "graph_module.graph.nodes",
      [](const nlohmann::json& nodes_array) -> nlohmann::json {
        nlohmann::json upgraded_nodes = nodes_array;

        // Process each node in the nodes array
        for (auto& node : upgraded_nodes) {
          if (node.contains("metadata") && node["metadata"].is_object()) {
            // Process each metadata key-value pair
            for (auto& [key, value] : node["metadata"].items()) {
              if (key == "nn_module_stack") {
                // Transform nn_module_stack values by prepending prefix
                if (value.is_string()) {
                  std::string stack_str = value.get<std::string>();
                  value = "test_upgrader_" + stack_str;
                } else {
                  throwUpgraderError(
                      "version_0_upgrader_registered",
                      0,
                      "nn_module_stack metadata value must be a string, got: " +
                          std::string(value.type_name()),
                      node);
                }
              }
              // Other metadata keys remain unchanged
            }
          }
        }

        return upgraded_nodes;
      });

  registerUpgrader(
      0,
      "graph_module.graph",
      [](const nlohmann::json& graph_obj) -> nlohmann::json {
        nlohmann::json upgraded_graph = graph_obj;

        // Rename field if it exists in the graph object
        if (upgraded_graph.contains("old_test_field")) {
          upgraded_graph["new_test_field"] = upgraded_graph["old_test_field"];
          upgraded_graph.erase("old_test_field");
        }

        return upgraded_graph;
      });

  registerUpgrader(
      1,
      std::vector<std::string>{"graph_module", "graph"},
      [](const nlohmann::json& graph_obj) -> nlohmann::json {
        nlohmann::json upgraded_graph = graph_obj;

        // Continue the field renaming chain from version 0
        if (upgraded_graph.contains("new_test_field")) {
          upgraded_graph["new_test_field2"] = upgraded_graph["new_test_field"];
          upgraded_graph.erase("new_test_field");
        }

        return upgraded_graph;
      });

  test_upgraders_registered = true;
}

/// Deregisters the example upgraders from the upgrader system.
void deregisterExampleUpgraders() {
  deregisterUpgrader(0, "graph_module.graph.nodes");
  deregisterUpgrader(0, "graph_module.graph");
  deregisterUpgrader(1, std::vector<std::string>{"graph_module", "graph"});
  test_upgraders_registered = false;
}

} // namespace torch::_export
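For a sense of how the example upgraders above compose, here is a standalone sketch (not part of the patch, illustrative only) that feeds a version-0 payload through upgrade(): the two v0 upgraders rewrite nn_module_stack and rename old_test_field to new_test_field, and the v1 upgrader then renames it to new_test_field2.

// Illustrative sketch only; exercises the APIs introduced by this patch.
#include <iostream>
#include <nlohmann/json.hpp>
#include <torch/csrc/export/example_upgraders.h>
#include <torch/csrc/export/upgrader.h>

int main() {
  torch::_export::registerExampleUpgraders();

  auto artifact = nlohmann::json::parse(R"({
    "schema_version": {"major": 0, "minor": 0},
    "graph_module": {"graph": {
      "nodes": [{"metadata": {"nn_module_stack": "stack"}}],
      "old_test_field": "value"
    }}
  })");

  // Walks both registered v0 keypaths, then the v1 keypath, ending at major 2.
  auto upgraded = torch::_export::upgrade(artifact, 2);
  std::cout << upgraded["graph_module"]["graph"]["new_test_field2"] << "\n"; // "value"

  torch::_export::deregisterExampleUpgraders();
  return 0;
}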
diff --git a/torch/csrc/export/example_upgraders.h b/torch/csrc/export/example_upgraders.h
new file mode 100644
index 00000000000..40e1fb14e72
--- /dev/null
+++ b/torch/csrc/export/example_upgraders.h
@@ -0,0 +1,15 @@
#pragma once

namespace torch::_export {

/// Register example upgraders for the upgrader system for testing.
/// This function demonstrates common upgrade patterns and is primarily
/// used for testing and demonstration purposes.
void registerExampleUpgraders();

/// Deregister example upgraders for the upgrader system for testing.
/// This function cleans up the example upgraders that were registered
/// by registerExampleUpgraders().
void deregisterExampleUpgraders();

} // namespace torch::_export
diff --git a/torch/csrc/export/pybind.cpp b/torch/csrc/export/pybind.cpp
index 65206d06dbe..eedd8666ea1 100644
--- a/torch/csrc/export/pybind.cpp
+++ b/torch/csrc/export/pybind.cpp
@@ -1,5 +1,7 @@
+#include <torch/csrc/export/example_upgraders.h>
 #include <torch/csrc/export/pt2_archive_constants.h>
 #include <torch/csrc/export/pybind.h>
+#include <torch/csrc/export/upgrader.h>
 #include <torch/csrc/utils/generated_serialization_types.h>
 #include <torch/csrc/utils/pybind.h>
@@ -15,13 +17,37 @@ void initExportBindings(PyObject* module) {
   exportModule.def(
       "deserialize_exported_program", [](const std::string& serialized) {
-        return nlohmann::json::parse(serialized).get<ExportedProgram>();
+        auto parsed = nlohmann::json::parse(serialized);
+
+        // Query the current Python schema version as target
+        // TODO: expose schema_version in generated_serialization_types.h and
+        // access it here directly.
+        py::module_ schema_module =
+            py::module_::import("torch._export.serde.schema");
+        py::tuple schema_version_tuple = schema_module.attr("SCHEMA_VERSION");
+        int target_version = schema_version_tuple[0].cast<int>();
+
+        auto upgraded = upgrade(parsed, target_version);
+        return upgraded.get<ExportedProgram>();
       });

   exportModule.def("serialize_exported_program", [](const ExportedProgram& ep) {
     return nlohmann::json(ep).dump();
   });

+  exportModule.def(
+      "upgrade", [](const std::string& serialized_json, int target_version) {
+        auto parsed = nlohmann::json::parse(serialized_json);
+        auto upgraded = upgrade(parsed, target_version);
+        return upgraded.dump();
+      });
+
+  exportModule.def(
+      "register_example_upgraders", []() { registerExampleUpgraders(); });
+
+  exportModule.def(
+      "deregister_example_upgraders", []() { deregisterExampleUpgraders(); });
+
   for (const auto& entry : torch::_export::archive_spec::kAllConstants) {
     pt2ArchiveModule.attr(entry.first) = entry.second;
   }
diff --git a/torch/csrc/export/upgrader.cpp b/torch/csrc/export/upgrader.cpp
new file mode 100644
index 00000000000..9f92239840b
--- /dev/null
+++ b/torch/csrc/export/upgrader.cpp
@@ -0,0 +1,242 @@
#include <torch/csrc/export/upgrader.h>
#include <map>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

namespace torch::_export {

// Global upgrader registry organized by version.
// Using std::multiset to maintain automatic bottom-up ordering where
// deeper keypaths are processed before shallower ones.
static std::map<int, std::multiset<Upgrader>> upgrader_registry;

static const std::multiset<Upgrader>& getUpgrader(int current_version) {
  static const std::multiset<Upgrader> empty_upgraders;
  auto it = upgrader_registry.find(current_version);
  if (it != upgrader_registry.end()) {
    return it->second;
  }
  return empty_upgraders;
}

static nlohmann::json getFieldByKeypath(
    const nlohmann::json& obj,
    const std::vector<std::string>& keypath) {
  nlohmann::json current = obj;
  for (const auto& key : keypath) {
    if (!current.contains(key)) {
      throw std::runtime_error("Keypath not found: " + key);
    }
    current = current[key];
  }
  return current;
}

static void setFieldByKeypath(
    nlohmann::json& obj,
    const std::vector<std::string>& keypath,
    const nlohmann::json& value) {
  nlohmann::json* current = &obj;
  for (size_t i = 0; i < keypath.size() - 1; ++i) {
    const auto& key = keypath[i];
    if (!current->contains(key)) {
      throw std::runtime_error("Keypath not found: " + key);
    }
    current = &((*current)[key]);
  }
  if (!current->contains(keypath.back())) {
    throw std::runtime_error("Keypath not found: " + keypath.back());
  }
  (*current)[keypath.back()] = value;
}

Upgrader::Upgrader(std::vector<std::string> kp, UpgraderFunction func)
    : keypath(std::move(kp)), upgrade_func(std::move(func)) {}

bool Upgrader::operator<(const Upgrader& other) const {
  // First compare by depth - deeper paths come first for bottom-up processing
  if (keypath.size() != other.keypath.size()) {
    return keypath.size() > other.keypath.size();
  }
  // If same depth, compare lexicographically for deterministic ordering
  return keypath < other.keypath;
}

void registerUpgrader(
    int version,
    const std::vector<std::string>& keypath,
    const UpgraderFunction& upgrade_func) {
  // Check if an upgrader already exists for this version and keypath
  auto version_it = upgrader_registry.find(version);
  if (version_it != upgrader_registry.end()) {
    const auto& upgraders = version_it->second;

    // Search for existing upgrader with the same keypath
    for (const auto& existing_upgrader : upgraders) {
      if (existing_upgrader.keypath == keypath) {
        std::ostringstream error_stream;
        error_stream << "Upgrader already registered for version " << version
                     << " and keypath: ";
        for (size_t i = 0; i < keypath.size(); ++i) {
          if (i > 0)
            error_stream << ".";
          error_stream << keypath[i];
        }
        throw std::runtime_error(error_stream.str());
      }
    }
  }

  upgrader_registry[version].emplace(keypath, upgrade_func);
}

void registerUpgrader(
    int version,
    const std::string& dot_keypath,
    const UpgraderFunction& upgrade_func) {
  // Convert dot-separated keypath to vector and delegate to main implementation
  std::vector<std::string> keypath_vector;
  std::stringstream ss(dot_keypath);
  std::string component;

  while (std::getline(ss, component, '.')) {
    if (component.empty()) {
      throw std::invalid_argument("Empty component in keypath: " + dot_keypath);
    }
    keypath_vector.push_back(component);
  }

  if (keypath_vector.empty()) {
    throw std::invalid_argument("Empty keypath provided");
  }

  registerUpgrader(version, keypath_vector, upgrade_func);
}

bool deregisterUpgrader(int version, const std::vector<std::string>& keypath) {
  auto version_it = upgrader_registry.find(version);
  if (version_it == upgrader_registry.end()) {
    return false; // Version not found
  }

  auto& upgraders = version_it->second;

  // Find the upgrader with matching keypath
  for (auto it = upgraders.begin(); it != upgraders.end(); ++it) {
    if (it->keypath == keypath) {
      upgraders.erase(it);
      // If this was the last upgrader for this version, remove the version
      // entry
      if (upgraders.empty()) {
        upgrader_registry.erase(version_it);
      }

      return true; // Successfully removed
    }
  }

  return false; // Upgrader not found
}

bool deregisterUpgrader(int version, const std::string& dot_keypath) {
  // Convert dot-separated keypath to vector and delegate to main implementation
  std::vector<std::string> keypath_vector;
  std::stringstream ss(dot_keypath);
  std::string component;

  while (std::getline(ss, component, '.')) {
    if (component.empty()) {
      throw std::invalid_argument("Empty component in keypath: " + dot_keypath);
    }
    keypath_vector.push_back(component);
  }

  if (keypath_vector.empty()) {
    throw std::invalid_argument("Empty keypath provided");
  }

  return deregisterUpgrader(version, keypath_vector);
}

void throwUpgraderError(
    const std::string& upgrader_name,
    int from_version,
    const std::string& error_message,
    const nlohmann::json& problematic_object) {
  std::ostringstream error_stream;
  error_stream << "Error in upgrader '" << upgrader_name << "' "
               << "while upgrading from version " << from_version
               << " to version " << from_version + 1 << ": " << error_message;

  if (!problematic_object.empty()) {
    error_stream << "\nProblematic object: " << problematic_object.dump(2);
  }

  throw std::runtime_error(error_stream.str());
}

nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
  auto current_artifact = artifact;

  // Validate that the artifact contains required schema version information
  if (!current_artifact.contains("schema_version")) {
    throw std::runtime_error("Missing schema_version field in artifact");
  }

  int current_version = current_artifact["schema_version"]["major"];

  // Iteratively apply upgraders until target version is reached or no more are
  // available
  while (current_version < target_version) {
    // Look up upgraders for the current version
    const auto& upgraders = getUpgrader(current_version);

    if (upgraders.empty()) {
      // No more upgraders available - stop upgrading
      break;
    }

    // Apply all upgraders for this version in bottom-up order
    // (deeper keypaths first to prevent parent/child conflicts)
    for (const auto& upgrader : upgraders) {
      // Extract the field to be upgraded using its keypath
      auto field_to_upgrade =
          getFieldByKeypath(current_artifact, upgrader.keypath);

      // Apply the upgrade transformation
      auto upgraded_field = upgrader.upgrade_func(field_to_upgrade);

      // Update the artifact with the upgraded field
      setFieldByKeypath(current_artifact, upgrader.keypath, upgraded_field);
    }

    // Move to the next version for potential additional upgrades
    current_version++;
  }

  // Update schema version to reflect the final upgraded version
  if (current_artifact["schema_version"]["major"] != current_version) {
    current_artifact["schema_version"]["major"] = current_version;
    // Reset minor version to 0 - the correct minor version should be set
    // when converting the json to in memory representation of ExportedProgram
    current_artifact["schema_version"]["minor"] = 0;
  }

  // Validate that we reached the target version if requested
  if (current_version != target_version) {
    std::ostringstream error_stream;
    error_stream
        << "Failed to upgrade to target version " << target_version
        << ". Final version reached: " << current_version
        << ". This may indicate missing upgraders for intermediate versions.";
    throw std::runtime_error(error_stream.str());
  }

  return current_artifact;
}

} // namespace torch::_export
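To make the registration flow concrete, here is a hedged sketch of how an out-of-tree caller might use this API for a real schema bump. It is not part of the patch; the version numbers, keypath, and field names are hypothetical.

// Sketch only: registers a hypothetical upgrader for schema version 5 that
// renames graph_module.graph.foo to bar, then applies the chain.
#include <torch/csrc/export/upgrader.h>

void upgradeExampleArtifact(nlohmann::json& artifact) {
  torch::_export::registerUpgrader(
      5,
      "graph_module.graph",
      [](const nlohmann::json& graph) -> nlohmann::json {
        nlohmann::json out = graph;
        if (out.contains("foo")) {
          out["bar"] = out["foo"];
          out.erase("foo");
        }
        return out;
      });

  // Throws std::runtime_error if the chain cannot reach version 6.
  artifact = torch::_export::upgrade(artifact, 6);

  torch::_export::deregisterUpgrader(5, "graph_module.graph");
}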
diff --git a/torch/csrc/export/upgrader.h b/torch/csrc/export/upgrader.h
new file mode 100644
index 00000000000..c9e9b8f7ff1
--- /dev/null
+++ b/torch/csrc/export/upgrader.h
@@ -0,0 +1,118 @@
#pragma once

#include <nlohmann/json.hpp>
#include <functional>
#include <string>
#include <vector>

namespace torch::_export {

/// Function type for upgrading JSON fields during schema version migration.
/// Takes a JSON field and returns the upgraded version of that field.
using UpgraderFunction = std::function<nlohmann::json(const nlohmann::json&)>;

/// Structure containing upgrader information for a specific keypath.
/// The version is stored as the map key in the registry, so it's not
/// duplicated here.
struct Upgrader {
  /// Path to the field that should be upgraded (e.g., {"graph_module",
  /// "graph", "nodes"}), assuming the top-level JSON object represents an
  /// ExportedProgram.
  std::vector<std::string> keypath;

  /// Function that performs the actual upgrade transformation
  UpgraderFunction upgrade_func;

  /// Constructor for creating an upgrader with keypath and function
  Upgrader(std::vector<std::string> kp, UpgraderFunction func);

  /// Comparator for maintaining bottom-up ordering in the registry.
  /// Deeper keypaths are processed first to ensure safe upgrade application
  /// without conflicts between parent and child field modifications.
  bool operator<(const Upgrader& other) const;
};

/// Register an upgrader function for a specific schema version and keypath.
///
/// This function allows registration of custom upgrade logic that will be
/// applied when upgrading artifacts from the specified version. Upgraders
/// are applied in bottom-up order (deeper keypaths first) to prevent
/// conflicts between parent and child field modifications.
///
/// @param version The schema version this upgrader applies to
/// @param keypath The key path to the field that should be upgraded
/// @param upgrade_func Function that performs the upgrade transformation
void registerUpgrader(
    int version,
    const std::vector<std::string>& keypath,
    const UpgraderFunction& upgrade_func);

/// Register an upgrader function using dot-separated keypath notation.
///
/// Convenience overload that accepts dot-separated keypath strings for
/// simpler syntax. For example: "graph_module.graph.nodes" instead of
/// {"graph_module", "graph", "nodes"}.
///
/// @param version The schema version this upgrader applies to
/// @param dot_keypath Dot-separated keypath string (e.g., "graph.nodes")
/// @param upgrade_func Function that performs the upgrade transformation
void registerUpgrader(
    int version,
    const std::string& dot_keypath,
    const UpgraderFunction& upgrade_func);

/// Deregister an upgrader function for a specific schema version and keypath.
///
/// This function allows removal of previously registered upgrade logic for
/// the specified version and keypath. This is useful for testing scenarios
/// where you need to clean up registered upgraders or modify upgrader
/// behavior dynamically.
///
/// @param version The schema version to deregister the upgrader from
/// @param keypath The key path to the field that should be deregistered
/// @return true if an upgrader was found and removed, false otherwise
bool deregisterUpgrader(int version, const std::vector<std::string>& keypath);
/// Deregister an upgrader function using dot-separated keypath notation.
///
/// Convenience overload that accepts dot-separated keypath strings for
/// simpler syntax. For example: "graph_module.graph.nodes" instead of
/// {"graph_module", "graph", "nodes"}.
///
/// @param version The schema version to deregister the upgrader from
/// @param dot_keypath Dot-separated keypath string (e.g., "graph.nodes")
/// @return true if an upgrader was found and removed, false otherwise
bool deregisterUpgrader(int version, const std::string& dot_keypath);

/// Utility function for throwing consistent upgrader errors.
///
/// This function formats error messages in a standardized way for upgrader
/// failures, including version information and optional problematic object
/// details for debugging.
///
/// @param upgrader_name Name of the upgrader that failed
/// @param from_version Source schema version being upgraded from
/// @param error_message Descriptive error message
/// @param problematic_object Optional JSON object that caused the error
/// @throws std::runtime_error Always throws with formatted error message
void throwUpgraderError(
    const std::string& upgrader_name,
    int from_version,
    const std::string& error_message,
    const nlohmann::json& problematic_object = nlohmann::json::object());

/// Upgrade a JSON artifact by applying the available upgraders until the
/// target schema version is reached.
///
/// This handles major version upgrades only. Minor version upgrades,
/// e.g. adding a new field with a default value, are handled automatically
/// by the default constructors in generated_serialization_types.h.
///
/// @param artifact The JSON artifact to upgrade
/// @param target_version The target schema version to upgrade to
/// @return The upgraded JSON artifact with updated schema version
/// @throws std::runtime_error if artifact is missing schema_version field
/// @throws std::runtime_error if final version doesn't match target version
nlohmann::json upgrade(const nlohmann::json& artifact, int target_version);

} // namespace torch::_export
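A note on the failure mode documented above: when no upgrader chain can reach the requested version, upgrade() throws rather than silently returning a partially upgraded artifact. A minimal sketch (not part of the patch, and assuming no upgraders are registered for version 3):

#include <iostream>
#include <stdexcept>
#include <torch/csrc/export/upgrader.h>

int main() {
  nlohmann::json stale = {{"schema_version", {{"major", 3}, {"minor", 0}}}};
  try {
    torch::_export::upgrade(stale, 4); // no upgrader registered for version 3
  } catch (const std::runtime_error& e) {
    // e.g. "Failed to upgrade to target version 4. Final version reached: 3. ..."
    std::cout << e.what() << "\n";
  }
  return 0;
}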