From b0fdca885547efedddc4e9cfe78a5e0d7838e9f0 Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Wed, 5 Jan 2022 23:55:49 -0800 Subject: [PATCH] Bump version number to 7 and compile old operators with old schema (#68358) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68358 Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D33433730 Pulled By: tugsbayasgalan fbshipit-source-id: 202c58365bae13195d3545cefcb0da9162b02151 --- caffe2/serialize/inline_container.h | 6 + caffe2/serialize/versions.h | 27 +- test/cpp/jit/CMakeLists.txt | 1 + test/cpp/jit/test_load_upgraders.cpp | 45 ++ test/jit/test_legacy_upgraders.py | 553 ++++++++++++++++++ test/jit/test_save_load.py | 17 - test/jit/test_save_load_for_op_version.py | 198 ------- test/jit/test_upgraders.py | 130 +++- test/test_jit.py | 1 + tools/build_variables.bzl | 4 +- torch/_C/__init__.pyi.in | 2 +- torch/csrc/jit/frontend/builtin_functions.cpp | 6 +- torch/csrc/jit/frontend/ir_emitter.cpp | 9 + torch/csrc/jit/frontend/schema_matching.cpp | 18 +- torch/csrc/jit/frontend/schema_matching.h | 2 + torch/csrc/jit/frontend/sugared_value.h | 3 + torch/csrc/jit/frontend/versioned_symbols.cpp | 4 +- torch/csrc/jit/frontend/versioned_symbols.h | 5 +- .../csrc/jit/operator_upgraders/upgraders.cpp | 26 +- torch/csrc/jit/operator_upgraders/upgraders.h | 15 +- .../operator_upgraders/upgraders_entry.cpp | 78 +++ .../jit/operator_upgraders/upgraders_entry.h | 14 + .../operator_upgraders/upgraders_guard.cpp | 12 + .../jit/operator_upgraders/upgraders_guard.h | 10 + .../jit/operator_upgraders/version_map.cpp | 4 +- .../passes/replacement_of_old_operators.cpp | 37 +- torch/csrc/jit/python/script_init.cpp | 4 +- torch/csrc/jit/serialization/import.cpp | 6 + .../csrc/jit/serialization/import_source.cpp | 4 +- torch/csrc/jit/serialization/python_print.cpp | 21 +- torch/jit/_serialization.py | 5 - torch/jit/operator_upgraders.py | 7 - 32 files changed, 993 insertions(+), 281 deletions(-) create mode 100644 test/cpp/jit/test_load_upgraders.cpp create mode 100644 test/jit/test_legacy_upgraders.py create mode 100644 torch/csrc/jit/operator_upgraders/upgraders_entry.cpp create mode 100644 torch/csrc/jit/operator_upgraders/upgraders_entry.h create mode 100644 torch/csrc/jit/operator_upgraders/upgraders_guard.cpp create mode 100644 torch/csrc/jit/operator_upgraders/upgraders_guard.h diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 4eb1b8e71ce..139174fa3d6 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -164,7 +164,13 @@ class TORCH_API PyTorchStreamWriter final { std::string padding_; std::ofstream file_stream_; std::function writer_func_; + // This number will be updated when the model has operators + // that have valid upgraders. +#if ENABLE_UPGRADERS + uint64_t version_ = kMinProducedFileFormatVersion; +#else uint64_t version_ = kProducedFileFormatVersion; +#endif bool finalized_ = false; bool err_seen_ = false; friend size_t ostream_write_func( diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h index 9fea8fc5590..ee682c68628 100644 --- a/caffe2/serialize/versions.h +++ b/caffe2/serialize/versions.h @@ -1,12 +1,21 @@ #pragma once - #include namespace caffe2 { namespace serialize { +// Flag that controls if we want to enable upgraders +// in the server side. When this flag is set to False, +// it will switch to old dynamic versioning approach +#define ENABLE_UPGRADERS false + constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; + +#if ENABLE_UPGRADERS +constexpr uint64_t kMaxSupportedFileFormatVersion = 0x7L; +#else constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L; +#endif // Versions (i.e. why was the version number bumped?) @@ -47,7 +56,23 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L; // 5. (Dynamic) Stops torch.full inferring a floating point dtype // when given bool or integer fill values. // 6. Write version string to `./data/version` instead of `version`. + +#if ENABLE_UPGRADERS +// This is set to 7 from 3 due to a different interpretation of what +// file format version is. Whenever there is new upgrader introduced, +// this number should be bumped. +// 1. aten::div is changed at version 4 +// 2. aten::full is changed at version 5 +// 3. torch.package uses version 6 +constexpr uint64_t kProducedFileFormatVersion = 0x7L; +#else constexpr uint64_t kProducedFileFormatVersion = 0x3L; +#endif + +// Absolute minimum version we will write packages. This +// means that every package from now on will always be +// greater than this number. +constexpr uint64_t kMinProducedFileFormatVersion = 0x3L; // The version we write when the archive contains bytecode. // It must be higher or eq to kProducedFileFormatVersion. diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index ca6b869f71b..a86910bd0c5 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -42,6 +42,7 @@ set(JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_alias_analysis.cpp ${JIT_TEST_ROOT}/test_argument_spec.cpp ${JIT_TEST_ROOT}/test_autodiff.cpp + ${JIT_TEST_ROOT}/test_load_upgraders.cpp ${JIT_TEST_ROOT}/test_op_replacement.cpp ${JIT_TEST_ROOT}/test_upgrader_utils.cpp ${JIT_TEST_ROOT}/test_backend.cpp diff --git a/test/cpp/jit/test_load_upgraders.cpp b/test/cpp/jit/test_load_upgraders.cpp new file mode 100644 index 00000000000..ddbe46c8ada --- /dev/null +++ b/test/cpp/jit/test_load_upgraders.cpp @@ -0,0 +1,45 @@ +#include +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { + +#if ENABLE_UPGRADERS +// Basic tests to check if C++ torch::jit::load +// can load the upgraders fine +// TODO (tugsuu) add more tests +TEST(UpgraderLoad, CanPopulateUpgradersGraph) { + Module m("m"); + m.define(R"( + def forward(self, x: Tensor): + b = 5 + return torch.div(x, b) + )"); + std::stringstream ms; + m.save(ms); + auto loaded_m = torch::jit::load(ms); + auto version_map = get_operator_version_map(); + auto upgraders = dump_upgraders_map(); + + for (const auto& entry : version_map) { + auto list_of_upgraders_for_op = entry.second; + for (const auto& upgrader_entry : list_of_upgraders_for_op) { + EXPECT_TRUE( + upgraders.find(upgrader_entry.upgrader_name) != upgraders.end()); + } + } + + auto test_graph = loaded_m.get_method("forward").graph(); + // should have saved with version 4, so it is still up to date + testing::FileCheck().check_count("aten::div", 1, true)->run(*test_graph); +} +#endif + +} // namespace jit +} // namespace torch diff --git a/test/jit/test_legacy_upgraders.py b/test/jit/test_legacy_upgraders.py new file mode 100644 index 00000000000..e4c0d588ee5 --- /dev/null +++ b/test/jit/test_legacy_upgraders.py @@ -0,0 +1,553 @@ +# Owner(s): ["oncall: jit"] + +from itertools import product as product +import io +import os +import random +import sys +import unittest + +import torch +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase + +if __name__ == "__main__": + raise RuntimeError( + "This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead." + ) + +# Legacy test cases for dynamic operator versioning +class TestLegacyUpgraders(JitTestCase): + reason: str = "Skipped due to new global operator versioning for operators" + + @unittest.skipIf(torch._C._is_upgraders_enabled(), reason) + def test_versioned_symbols(self): + """ + Tests Torchscript symbol versioning. See note [Versioned Symbols]. + This test uses an undocumented, test-only function + torch._test_serialization_subcmul. + This function is implemented as (a - alpha * b) with a default value + of 1 for alpha. In file format version 2, however, it was implemented + as (b - alpha * a) with a default value of 2 for alpha. + This test verifies a module seralized with file format version 2 + exhibits the old behavior, and that the same module newly serialized + exhibits the current behavior. + """ + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, a, b, alpha: float): + no_alpha = torch._test_serialization_subcmul(a, b) + with_alpha = torch._test_serialization_subcmul(a, b, alpha) + return no_alpha, with_alpha + + def historic_subcmul(a, b, alpha=2): + return b - alpha * a + + def current_subcmul(a, b, alpha=1): + return a - alpha * b + + # Loads and verifies the historic behavior of the module + # that was serialized with version 2 + module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt") + a = torch.randn((5,)) + b = torch.randn((5,)) + alpha = random.random() + args = (a, b, alpha) + no_alpha_v2, with_alpha_v2 = module_v2(*args) + self.assertEqual(no_alpha_v2, historic_subcmul(a, b)) + self.assertEqual(with_alpha_v2, historic_subcmul(*args)) + + # Scripts, saves, loads and verifies the current behavior of the module + scripted_module = torch.jit.script(MyModule()) + buffer = io.BytesIO() + torch.jit.save(scripted_module, buffer) + buffer.seek(0) + module_current = torch.jit.load(buffer) + no_alpha_current, with_alpha_current = module_current(*args) + self.assertEqual(no_alpha_current, current_subcmul(a, b)) + self.assertEqual(with_alpha_current, current_subcmul(*args)) + + # Helper that returns the module after saving and loading + def _save_load_module(self, m): + scripted_module = torch.jit.script(m()) + buffer = io.BytesIO() + torch.jit.save(scripted_module, buffer) + buffer.seek(0) + return torch.jit.load(buffer) + + # Helper which returns the result of a function or the exception the + # function threw. + def _try_fn(self, fn, *args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + return e + + def _verify_no(self, kind, m): + self._verify_count(kind, m, 0) + + def _verify_count(self, kind, m, count): + node_count = sum(str(n).count(kind) for n in m.graph.nodes()) + self.assertEqual(node_count, count) + + """ + Tests that verify Torchscript remaps aten::div(_) from versions 0-3 + to call either aten::true_divide(_), if an input is a float type, + or truncated aten::divide(_) otherwise. + NOTE: currently compares against current div behavior, too, since + div behavior has not yet been updated. + """ + + @unittest.skipIf(torch._C._is_upgraders_enabled(), reason) + def test_versioned_div_tensor(self): + def historic_div(self, other): + if self.is_floating_point() or other.is_floating_point(): + return self.true_divide(other) + return self.divide(other, rounding_mode='trunc') + + # Tensor x Tensor + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, a, b): + result_0 = a / b + result_1 = torch.div(a, b) + result_2 = a.div(b) + + return result_0, result_1, result_2 + + # Loads historic module + try: + v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt") + except Exception as e: + self.skipTest("Failed to load fixture!") + + self._verify_count("aten::div", v3_module, 6) # true_divide and divide alias to div + self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument + + current_module = self._save_load_module(MyModule) + self._verify_count("aten::div", current_module, 3) + + vals = (2., 3., 2, 3) + for val_a, val_b in product(vals, vals): + a = torch.tensor((val_a,)) + b = torch.tensor((val_b,)) + + def _helper(m, fn): + m_results = self._try_fn(m, a, b) + fn_result = self._try_fn(fn, a, b) + + if isinstance(m_results, Exception): + self.assertTrue(isinstance(fn_result, Exception)) + else: + for result in m_results: + self.assertEqual(result, fn_result) + + _helper(v3_module, historic_div) + _helper(current_module, torch.div) + + @unittest.skipIf(torch._C._is_upgraders_enabled(), reason) + def test_versioned_div_tensor_inplace(self): + def historic_div_(self, other): + if self.is_floating_point() or other.is_floating_point(): + return self.true_divide_(other) + return self.divide_(other, rounding_mode='trunc') + + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, a, b): + a /= b + return a + + try: + v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt") + except Exception as e: + self.skipTest("Failed to load fixture!") + + self._verify_count("aten::div", v3_module, 2) # true_divide and divide both alias to div + self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument + + current_module = self._save_load_module(MyModule) + self._verify_count("aten::div", current_module, 1) + + vals = (2., 3., 2, 3) + for val_a, val_b in product(vals, vals): + a = torch.tensor((val_a,)) + b = torch.tensor((val_b,)) + + def _helper(m, fn): + fn_result = self._try_fn(fn, a.clone(), b) + m_result = self._try_fn(m, a, b) + + if isinstance(m_result, Exception): + self.assertTrue(fn_result, Exception) + else: + self.assertEqual(m_result, fn_result) + self.assertEqual(m_result, a) + + _helper(v3_module, historic_div_) + + # Recreates a since it was modified in place + a = torch.tensor((val_a,)) + _helper(current_module, torch.Tensor.div_) + + @unittest.skipIf(torch._C._is_upgraders_enabled(), reason) + def test_versioned_div_tensor_out(self): + def historic_div_out(self, other, out): + if self.is_floating_point() or other.is_floating_point() or out.is_floating_point(): + return torch.true_divide(self, other, out=out) + return torch.divide(self, other, out=out, rounding_mode='trunc') + + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, a, b, out): + return a.div(b, out=out) + + try: + v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt") + except Exception as e: + self.skipTest("Failed to load fixture!") + + self._verify_count("aten::div", v3_module, 2) # true_divide and divide alias to div + self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument + + current_module = self._save_load_module(MyModule) + self._verify_count("aten::div", current_module, 1) + + vals = (2., 3., 2, 3) + for val_a, val_b in product(vals, vals): + a = torch.tensor((val_a,)) + b = torch.tensor((val_b,)) + + for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)): + def _helper(m, fn): + fn_result = None + if fn is torch.div: + fn_result = self._try_fn(fn, a, b, out=out.clone()) + else: + fn_result = self._try_fn(fn, a, b, out.clone()) + m_result = self._try_fn(m, a, b, out) + + if isinstance(m_result, Exception): + self.assertTrue(fn_result, Exception) + else: + self.assertEqual(m_result, fn_result) + self.assertEqual(m_result, out) + + _helper(v3_module, historic_div_out) + _helper(current_module, torch.div) + + @unittest.skipIf(torch._C._is_upgraders_enabled(), reason) + def test_versioned_div_scalar(self): + def historic_div_scalar_float(self, other: float): + return torch.true_divide(self, other) + + def historic_div_scalar_int(self, other: int): + if self.is_floating_point(): + return torch.true_divide(self, other) + return torch.divide(self, other, rounding_mode='trunc') + + class MyModuleFloat(torch.nn.Module): + def __init__(self): + super(MyModuleFloat, self).__init__() + + def forward(self, a, b: float): + return a / b + + class MyModuleInt(torch.nn.Module): + def __init__(self): + super(MyModuleInt, self).__init__() + + def forward(self, a, b: int): + return a / b + + try: + v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt") + v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_int_v3.pt") + except Exception as e: + self.skipTest("Failed to load fixture!") + + for m in (v3_module_float, v3_module_int): + self._verify_count("aten::div", m, 2) # true_divide and divide alias to div + self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument + + current_module_float = self._save_load_module(MyModuleFloat) + current_module_int = self._save_load_module(MyModuleInt) + + for m in (current_module_float, current_module_int): + self._verify_count("aten::div", m, 1) + + vals = (2., 3., 2, 3) + for val_a, val_b in product(vals, vals): + a = torch.tensor((val_a,)) + b = val_b + + def _helper(m, fn): + m_result = self._try_fn(m, a, b) + fn_result = self._try_fn(fn, a, b) + + if isinstance(m_result, Exception): + self.assertTrue(fn_result, Exception) + else: + self.assertEqual(m_result, fn_result) + + if isinstance(b, float): + _helper(v3_module_float, historic_div_scalar_float) + _helper(current_module_float, torch.div) + else: + _helper(v3_module_int, historic_div_scalar_int) + _helper(current_module_int, torch.div) + + @unittest.skipIf(torch._C._is_upgraders_enabled(), reason) + def test_versioned_div_scalar_reciprocal(self): + def historic_div_scalar_float_reciprocal(self, other: float): + return other / self + + def historic_div_scalar_int_reciprocal(self, other: int): + if self.is_floating_point(): + return other / self + return torch.divide(other, self, rounding_mode='trunc') + + class MyModuleFloat(torch.nn.Module): + def __init__(self): + super(MyModuleFloat, self).__init__() + + def forward(self, a, b: float): + return b / a + + class MyModuleInt(torch.nn.Module): + def __init__(self): + super(MyModuleInt, self).__init__() + + def forward(self, a, b: int): + return b / a + + try: + v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt") + v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt") + except Exception as e: + self.skipTest("Failed to load fixture!") + + # NOTE: number / tensor is rewritten to torch.reciprocal(a) * b + # so true_divide and floor_divide do not appear in their graphs + for m in (v3_module_float, v3_module_int): + self._verify_no("aten::div", m) + self._verify_no("aten::true_divide", m) + self._verify_no("aten::floor_divide", m) + self._verify_count("aten::reciprocal", m, 1) + + current_module_float = self._save_load_module(MyModuleFloat) + current_module_int = self._save_load_module(MyModuleInt) + + vals = (2., 3., 2, 3) + for val_a, val_b in product(vals, vals): + a = torch.tensor((val_a,)) + b = val_b + + def _helper(m, fn): + m_result = self._try_fn(m, a, b) + fn_result = None + # Reverses argument order for torch.div + if fn is torch.div: + fn_result = self._try_fn(torch.div, b, a) + else: + fn_result = self._try_fn(fn, a, b) + + if isinstance(m_result, Exception): + self.assertTrue(isinstance(fn_result, Exception)) + elif fn is torch.div or a.is_floating_point(): + self.assertEqual(m_result, fn_result) + else: + # Skip when fn is not torch.div and a is integral because + # historic_div_scalar_int performs floored division + pass + + if isinstance(b, float): + _helper(v3_module_float, historic_div_scalar_float_reciprocal) + _helper(current_module_float, torch.div) + else: + _helper(v3_module_int, historic_div_scalar_int_reciprocal) + _helper(current_module_int, torch.div) + + @unittest.skipIf(torch._C._is_upgraders_enabled(), reason) + def test_versioned_div_scalar_inplace(self): + def historic_div_scalar_float_inplace(self, other: float): + return self.true_divide_(other) + + def historic_div_scalar_int_inplace(self, other: int): + if self.is_floating_point(): + return self.true_divide_(other) + + return self.divide_(other, rounding_mode='trunc') + + class MyModuleFloat(torch.nn.Module): + def __init__(self): + super(MyModuleFloat, self).__init__() + + def forward(self, a, b: float): + a /= b + return a + + class MyModuleInt(torch.nn.Module): + def __init__(self): + super(MyModuleInt, self).__init__() + + def forward(self, a, b: int): + a /= b + return a + + try: + v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt") + v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt") + except Exception as e: + self.skipTest("Failed to load fixture!") + + for m in (v3_module_float, v3_module_int): + self._verify_count("aten::div_", m, 2) # true_divide and divide alias to div + self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument + + current_module_float = self._save_load_module(MyModuleFloat) + current_module_int = self._save_load_module(MyModuleInt) + + for m in (current_module_float, current_module_int): + self._verify_count("aten::div", m, 1) + + for m in (current_module_float, current_module_int): + self._verify_count("aten::div", m, 1) + + vals = (2., 3., 2, 3) + for val_a, val_b in product(vals, vals): + a = torch.tensor((val_a,)) + b = val_b + + def _helper(m, fn): + m_result = self._try_fn(m, a, b) + fn_result = self._try_fn(fn, a, b) + + if isinstance(m_result, Exception): + self.assertTrue(fn_result, Exception) + else: + self.assertEqual(m_result, fn_result) + + if isinstance(b, float): + _helper(v3_module_float, historic_div_scalar_float_inplace) + _helper(current_module_float, torch.Tensor.div_) + else: + _helper(v3_module_int, historic_div_scalar_int_inplace) + _helper(current_module_int, torch.Tensor.div_) + + # NOTE: Scalar division was already true division in op version 3, + # so this test verifies the behavior is unchanged. + @unittest.skipIf(torch._C._is_upgraders_enabled(), reason) + def test_versioned_div_scalar_scalar(self): + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, a: float, b: int, c: float, d: int): + result_0 = a / b + result_1 = a / c + result_2 = b / c + result_3 = b / d + return (result_0, result_1, result_2, result_3) + + try: + v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt") + except Exception as e: + self.skipTest("Failed to load fixture!") + + self._verify_count("aten::div", v3_module, 4) + + current_module = self._save_load_module(MyModule) + self._verify_count("aten::div", current_module, 4) + + def _helper(m, fn): + vals = (5., 3, 2., 7) + m_result = m(*vals) + fn_result = fn(*vals) + for mr, hr in zip(m_result, fn_result): + self.assertEqual(mr, hr) + + _helper(v3_module, current_module) + + # NOTE: the JIT was incapable of handling boolean fill values when + # PyTorch produced file format versions 0-4 + @unittest.skipIf(torch._C._is_upgraders_enabled(), reason) + def test_versioned_full_integer_value(self): + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, int_fill: int): + size = torch.Size(2, 2) + a = torch.full(size, int_fill) + b = torch.full(size, 1) + return (a, b) + + try: + v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt") + except Exception as e: + self.skipTest("Failed to load fixture!") + + self._verify_count("aten::full", v4_module, 2) + + current_module = self._save_load_module(MyModule) + self._verify_count("aten::full", current_module, 2) + + # Verifies historic integer type inference is float + # NOTE: only verifies floating point, not exact dtype, due to + # https://github.com/pytorch/pytorch/issues/40470 + results = v4_module(2) + for result in results: + self.assertTrue(result.is_floating_point()) + + # Verifies values are correct + a, b = results + self.assertTrue((a == 2.).all()) + self.assertTrue((b == 1.).all()) + + # Tests that torch.full behavior which is the same from prior versions + # to version 5 is preserved. + # NOTE: while torch.full in eager PyTorch accepts a requires_grad argument, + # it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363) + @unittest.skipIf(torch._C._is_upgraders_enabled(), reason) + def test_versioned_full_preserved(self): + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, float_fill: float): + size = (2, 2) + a = torch.full(size, 1.) + b = torch.full(size, float_fill) + c = torch.full(size, float_fill, dtype=torch.long) + + out = torch.empty(size, dtype=torch.long) + d = torch.full(size, float_fill, out=out) + + e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None, + layout=torch.strided, device='cpu') + return (a, b, c, d, e) + + try: + v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt") + except Exception as e: + self.skipTest("Failed to load fixture!") + + self._verify_count("aten::full", v4_module, 5) + + current_module = self._save_load_module(MyModule) + self._verify_count("aten::full", current_module, 5) + + self.assertEqual(v4_module(2.), current_module(2.)) diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 6752835c664..fbc1443024c 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -24,23 +24,6 @@ if __name__ == "__main__": ) class TestSaveLoad(JitTestCase): - - def test_versioned_symbols_reserialization(self): - """ - Tests that loading and saving serialized Torchscript with a versioned - symbol won't persist the original function and will inline the - versioned builtin. - """ - module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt") - buffer = io.BytesIO() - torch.jit.save(module_v2, buffer) - buffer.seek(0) - module_reserialized = torch.jit.load(buffer) - - subcmul_nodes = sum("subcmul" in n.kind() for - n in module_reserialized.graph.nodes()) - self.assertEqual(subcmul_nodes, 0) - def test_different_modules(self): """ Exercise the situation where we have the same qualified name diff --git a/test/jit/test_save_load_for_op_version.py b/test/jit/test_save_load_for_op_version.py index cb134c86bc7..279ec9e779f 100644 --- a/test/jit/test_save_load_for_op_version.py +++ b/test/jit/test_save_load_for_op_version.py @@ -3,7 +3,6 @@ from itertools import product as product import io import os -import random import sys import hypothesis.strategies as st from hypothesis import example, settings, given @@ -24,55 +23,6 @@ if __name__ == "__main__": ) class TestSaveLoadForOpVersion(JitTestCase): - def test_versioned_symbols(self): - """ - Tests Torchscript symbol versioning. See note [Versioned Symbols]. - This test uses an undocumented, test-only function - torch._test_serialization_subcmul. - - This function is implemented as (a - alpha * b) with a default value - of 1 for alpha. In file format version 2, however, it was implemented - as (b - alpha * a) with a default value of 2 for alpha. - This test verifies a module seralized with file format version 2 - exhibits the old behavior, and that the same module newly serialized - exhibits the current behavior. - """ - class MyModule(torch.nn.Module): - def __init__(self): - super(MyModule, self).__init__() - - def forward(self, a, b, alpha: float): - no_alpha = torch._test_serialization_subcmul(a, b) - with_alpha = torch._test_serialization_subcmul(a, b, alpha) - return no_alpha, with_alpha - - def historic_subcmul(a, b, alpha=2): - return b - alpha * a - - def current_subcmul(a, b, alpha=1): - return a - alpha * b - - # Loads and verifies the historic behavior of the module - # that was serialized with version 2 - module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt") - a = torch.randn((5,)) - b = torch.randn((5,)) - alpha = random.random() - args = (a, b, alpha) - no_alpha_v2, with_alpha_v2 = module_v2(*args) - self.assertEqual(no_alpha_v2, historic_subcmul(a, b)) - self.assertEqual(with_alpha_v2, historic_subcmul(*args)) - - # Scripts, saves, loads and verifies the current behavior of the module - scripted_module = torch.jit.script(MyModule()) - buffer = io.BytesIO() - torch.jit.save(scripted_module, buffer) - buffer.seek(0) - module_current = torch.jit.load(buffer) - no_alpha_current, with_alpha_current = module_current(*args) - self.assertEqual(no_alpha_current, current_subcmul(a, b)) - self.assertEqual(with_alpha_current, current_subcmul(*args)) - # Helper that returns the module after saving and loading def _save_load_module(self, m): scripted_module = torch.jit.script(m()) @@ -107,7 +57,6 @@ class TestSaveLoadForOpVersion(JitTestCase): Tests that verify Torchscript remaps aten::div(_) from versions 0-3 to call either aten::true_divide(_), if an input is a float type, or truncated aten::divide(_) otherwise. - NOTE: currently compares against current div behavior, too, since div behavior has not yet been updated. """ @@ -137,18 +86,12 @@ class TestSaveLoadForOpVersion(JitTestCase): # Loads historic module try: - v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt") v3_mobile_module = _load_for_lite_interpreter( pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl") except Exception as e: self.skipTest("Failed to load fixture!") - self._verify_count("aten::div", v3_module, 6) # true_divide and divide alias to div - self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument - - current_module = self._save_load_module(MyModule) current_mobile_module = self._save_load_mobile_module(MyModule) - self._verify_count("aten::div", current_module, 3) for val_a, val_b in product(sample_input, sample_input): a = torch.tensor((val_a,)) @@ -164,9 +107,7 @@ class TestSaveLoadForOpVersion(JitTestCase): for result in m_results: self.assertEqual(result, fn_result) - _helper(v3_module, historic_div) _helper(v3_mobile_module, historic_div) - _helper(current_module, torch.div) _helper(current_mobile_module, torch.div) @settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated @@ -189,18 +130,12 @@ class TestSaveLoadForOpVersion(JitTestCase): return a try: - v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt") v3_mobile_module = _load_for_lite_interpreter( pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl") except Exception as e: self.skipTest("Failed to load fixture!") - self._verify_count("aten::div", v3_module, 2) # true_divide and divide both alias to div - self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument - - current_module = self._save_load_module(MyModule) current_mobile_module = self._save_load_mobile_module(MyModule) - self._verify_count("aten::div", current_module, 1) for val_a, val_b in product(sample_input, sample_input): a = torch.tensor((val_a,)) @@ -215,12 +150,10 @@ class TestSaveLoadForOpVersion(JitTestCase): self.assertEqual(m_result, fn_result) self.assertEqual(m_result, a) - _helper(v3_module, historic_div_) _helper(v3_mobile_module, historic_div_) # Recreates a since it was modified in place a = torch.tensor((val_a,)) - _helper(current_module, torch.Tensor.div_) _helper(current_mobile_module, torch.Tensor.div_) @settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated @@ -242,18 +175,12 @@ class TestSaveLoadForOpVersion(JitTestCase): return a.div(b, out=out) try: - v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt") v3_mobile_module = _load_for_lite_interpreter( pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl") except Exception as e: self.skipTest("Failed to load fixture!") - self._verify_count("aten::div", v3_module, 2) # true_divide and divide alias to div - self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument - - current_module = self._save_load_module(MyModule) current_mobile_module = self._save_load_mobile_module(MyModule) - self._verify_count("aten::div", current_module, 1) for val_a, val_b in product(sample_input, sample_input): a = torch.tensor((val_a,)) @@ -274,8 +201,6 @@ class TestSaveLoadForOpVersion(JitTestCase): self.assertEqual(m_result, fn_result) self.assertEqual(m_result, out) - _helper(v3_module, historic_div_out) - _helper(current_module, torch.div) _helper(v3_mobile_module, historic_div_out) _helper(current_mobile_module, torch.div) @@ -308,8 +233,6 @@ class TestSaveLoadForOpVersion(JitTestCase): return a / b try: - v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt") - v3_module_int = torch.jit.load(pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v3.pt") v3_mobile_module_float = _load_for_lite_interpreter( pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl") v3_mobile_module_int = _load_for_lite_interpreter( @@ -321,14 +244,9 @@ class TestSaveLoadForOpVersion(JitTestCase): self._verify_count("aten::div", m, 2) # true_divide and divide alias to div self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument - current_module_float = self._save_load_module(MyModuleFloat) - current_module_int = self._save_load_module(MyModuleInt) current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat) current_mobile_module_int = self._save_load_mobile_module(MyModuleInt) - for m in (current_module_float, current_module_int): - self._verify_count("aten::div", m, 1) - for val_a, val_b in product(sample_input, sample_input): a = torch.tensor((val_a,)) b = val_b @@ -343,13 +261,9 @@ class TestSaveLoadForOpVersion(JitTestCase): self.assertEqual(m_result, fn_result) if isinstance(b, float): - _helper(v3_module_float, historic_div_scalar_float) - _helper(current_module_float, torch.div) _helper(v3_mobile_module_float, current_mobile_module_float) _helper(current_mobile_module_float, torch.div) else: - _helper(v3_module_int, historic_div_scalar_int) - _helper(current_module_int, torch.div) _helper(v3_mobile_module_int, historic_div_scalar_int) _helper(current_mobile_module_int, torch.div) @@ -382,8 +296,6 @@ class TestSaveLoadForOpVersion(JitTestCase): return b / a try: - v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt") - v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt") v3_mobile_module_float = _load_for_lite_interpreter( pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl") v3_mobile_module_int = _load_for_lite_interpreter( @@ -391,17 +303,6 @@ class TestSaveLoadForOpVersion(JitTestCase): except Exception as e: self.skipTest("Failed to load fixture!") - # NOTE: number / tensor is rewritten to torch.reciprocal(a) * b - # so true_divide and floor_divide do not appear in their graphs - for m in (v3_module_float, v3_module_int): - self._verify_no("aten::div", m) - self._verify_no("aten::true_divide", m) - self._verify_no("aten::floor_divide", m) - self._verify_count("aten::reciprocal", m, 1) - - current_module_float = self._save_load_module(MyModuleFloat) - current_module_int = self._save_load_module(MyModuleInt) - current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat) current_mobile_module_int = self._save_load_mobile_module(MyModuleInt) @@ -428,13 +329,9 @@ class TestSaveLoadForOpVersion(JitTestCase): pass if isinstance(b, float): - _helper(v3_module_float, historic_div_scalar_float_reciprocal) - _helper(current_module_float, torch.div) _helper(v3_mobile_module_float, current_mobile_module_float) _helper(current_mobile_module_float, torch.div) else: - _helper(v3_module_int, historic_div_scalar_int_reciprocal) - _helper(current_module_int, torch.div) _helper(v3_mobile_module_int, current_mobile_module_int) _helper(current_mobile_module_int, torch.div) @@ -470,9 +367,6 @@ class TestSaveLoadForOpVersion(JitTestCase): return a try: - v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt") - v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt") - v3_mobile_module_float = _load_for_lite_interpreter( pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl") v3_mobile_module_int = _load_for_lite_interpreter( @@ -480,22 +374,9 @@ class TestSaveLoadForOpVersion(JitTestCase): except Exception as e: self.skipTest("Failed to load fixture!") - for m in (v3_module_float, v3_module_int): - self._verify_count("aten::div_", m, 2) # true_divide and divide alias to div - self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument - - current_module_float = self._save_load_module(MyModuleFloat) - current_module_int = self._save_load_module(MyModuleInt) - current_mobile_module_float = self._save_load_module(MyModuleFloat) current_mobile_module_int = self._save_load_module(MyModuleInt) - for m in (current_module_float, current_module_int): - self._verify_count("aten::div", m, 1) - - for m in (current_module_float, current_module_int): - self._verify_count("aten::div", m, 1) - for val_a, val_b in product(sample_input, sample_input): a = torch.tensor((val_a,)) b = val_b @@ -510,12 +391,8 @@ class TestSaveLoadForOpVersion(JitTestCase): self.assertEqual(m_result, fn_result) if isinstance(b, float): - _helper(v3_module_float, historic_div_scalar_float_inplace) - _helper(current_module_float, torch.Tensor.div_) _helper(current_mobile_module_float, torch.Tensor.div_) else: - _helper(v3_module_int, historic_div_scalar_int_inplace) - _helper(current_module_int, torch.Tensor.div_) _helper(current_mobile_module_int, torch.Tensor.div_) # NOTE: Scalar division was already true division in op version 3, @@ -533,17 +410,12 @@ class TestSaveLoadForOpVersion(JitTestCase): return (result_0, result_1, result_2, result_3) try: - v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt") v3_mobile_module = _load_for_lite_interpreter( pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl") except Exception as e: self.skipTest("Failed to load fixture!") - self._verify_count("aten::div", v3_module, 4) - - current_module = self._save_load_module(MyModule) current_mobile_module = self._save_load_mobile_module(MyModule) - self._verify_count("aten::div", current_module, 4) def _helper(m, fn): vals = (5., 3, 2., 7) @@ -552,74 +424,4 @@ class TestSaveLoadForOpVersion(JitTestCase): for mr, hr in zip(m_result, fn_result): self.assertEqual(mr, hr) - _helper(v3_module, current_module) _helper(v3_mobile_module, current_mobile_module) - - # NOTE: the JIT was incapable of handling boolean fill values when - # PyTorch produced file format versions 0-4 - def test_versioned_full_integer_value(self): - class MyModule(torch.nn.Module): - def __init__(self): - super(MyModule, self).__init__() - - def forward(self, int_fill: int): - size = torch.Size(2, 2) - a = torch.full(size, int_fill) - b = torch.full(size, 1) - return (a, b) - - try: - v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt") - except Exception as e: - self.skipTest("Failed to load fixture!") - - self._verify_count("aten::full", v4_module, 2) - - current_module = self._save_load_module(MyModule) - self._verify_count("aten::full", current_module, 2) - - # Verifies historic integer type inference is float - # NOTE: only verifies floating point, not exact dtype, due to - # https://github.com/pytorch/pytorch/issues/40470 - results = v4_module(2) - for result in results: - self.assertTrue(result.is_floating_point()) - - # Verifies values are correct - a, b = results - self.assertTrue((a == 2.).all()) - self.assertTrue((b == 1.).all()) - - # Tests that torch.full behavior which is the same from prior versions - # to version 5 is preserved. - # NOTE: while torch.full in eager PyTorch accepts a requires_grad argument, - # it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363) - def test_versioned_full_preserved(self): - class MyModule(torch.nn.Module): - def __init__(self): - super(MyModule, self).__init__() - - def forward(self, float_fill: float): - size = (2, 2) - a = torch.full(size, 1.) - b = torch.full(size, float_fill) - c = torch.full(size, float_fill, dtype=torch.long) - - out = torch.empty(size, dtype=torch.long) - d = torch.full(size, float_fill, out=out) - - e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None, - layout=torch.strided, device='cpu') - return (a, b, c, d, e) - - try: - v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt") - except Exception as e: - self.skipTest("Failed to load fixture!") - - self._verify_count("aten::full", v4_module, 5) - - current_module = self._save_load_module(MyModule) - self._verify_count("aten::full", current_module, 5) - - self.assertEqual(v4_module(2.), current_module(2.)) diff --git a/test/jit/test_upgraders.py b/test/jit/test_upgraders.py index aaea3539f60..edfacf601d8 100644 --- a/test/jit/test_upgraders.py +++ b/test/jit/test_upgraders.py @@ -4,6 +4,11 @@ import io import os import sys import torch +import unittest +import zipfile +from torch.testing import FileCheck +from torch._C import _is_upgraders_enabled +from typing import Union # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -16,6 +21,15 @@ if __name__ == '__main__': "instead.") class TestUpgraders(JitTestCase): + def _load_model_version(self, loaded_model): + buffer = io.BytesIO() + torch.jit.save(loaded_model, buffer) + buffer.seek(0) + zipped_model = zipfile.ZipFile(buffer) + version = int(zipped_model.read('archive/version').decode("utf-8")) + return version + + # TODO (tugsuu) We should ideally be generating this test cases. def test_populated_upgrader_graph(self): @torch.jit.script def f(): @@ -67,7 +81,7 @@ class TestUpgraders(JitTestCase): # upgrader map should have populated now upgraders_size = torch._C._get_upgraders_map_size() - test_map = {"a": "b", "c": "d"} + test_map = {"a": str(torch._C.Graph()), "c": str(torch._C.Graph())} torch._C._test_only_populate_upgraders(test_map) upgraders_size_after_test = torch._C._get_upgraders_map_size() self.assertEqual(upgraders_size_after_test - upgraders_size, 2) @@ -81,3 +95,117 @@ class TestUpgraders(JitTestCase): upgraders_dump_after_remove_test = torch._C._dump_upgraders_map() self.assertTrue("a" not in upgraders_dump_after_remove_test) self.assertTrue("c" not in upgraders_dump_after_remove_test) + + @unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled") + def test_aten_div_tensor_at_3(self): + model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt" + loaded_model = torch.jit.load(model_path) + # there are 3 aten::div in this model + # And the upgrader for aten::div uses two + # div's because of if/else branch + FileCheck().check("prim::If").run(loaded_model.graph) + FileCheck().check_count("aten::div", 6).run(loaded_model.graph) + + buffer = io.BytesIO() + torch.jit.save(loaded_model, buffer) + buffer.seek(0) + version = self._load_model_version(loaded_model) + self.assertTrue(version == 4) + loaded_model_twice = torch.jit.load(buffer) + # we check by its code because graph variable names + # can be different every time + self.assertEqual(loaded_model.code, loaded_model_twice.code) + + @unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled") + def test_aten_test_serialization(self): + model_path = pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt" + + # add test version entry to the version map + upgrader_bumped_version = 3 + upgrader_name = "_test_serialization_subcmul_0_2" + upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor" + dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema) + + torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry) + + # add test upgrader in the upgraders map + @torch.jit.script + def _test_serialization_subcmul_0_2(self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2) -> torch.Tensor: + return other - (self * alpha) + torch._C._test_only_populate_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)}) + + # test if the server is able to find the test upgraders and apply to IR + loaded_model = torch.jit.load(model_path) + FileCheck().check_count("aten::mul", 2).run(loaded_model.graph) + FileCheck().check_count("aten::sub", 2).run(loaded_model.graph) + + buffer = io.BytesIO() + torch.jit.save(loaded_model, buffer) + buffer.seek(0) + version = self._load_model_version(loaded_model) + self.assertTrue(version == 3) + loaded_model_twice = torch.jit.load(buffer) + # we check by its' code because graph variable names + # can be different every time + self.assertEqual(loaded_model.code, loaded_model_twice.code) + torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul") + torch._C._test_only_remove_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)}) + + @unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled") + def test_aten_div_scalar_at_3(self): + model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt" + loaded_model = torch.jit.load(model_path) + FileCheck().check("prim::If").run(loaded_model.graph) + FileCheck().check_count("aten::div", 2).run(loaded_model.graph) + + buffer = io.BytesIO() + torch.jit.save(loaded_model, buffer) + buffer.seek(0) + version = self._load_model_version(loaded_model) + self.assertTrue(version == 4) + loaded_model_twice = torch.jit.load(buffer) + + self.assertEqual(loaded_model(torch.Tensor([5.0, 3.0]), 2.0), + loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0)) + + @unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled") + def test_aten_div_tensor_out_at_3(self): + model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt" + loaded_model = torch.jit.load(model_path) + FileCheck().check("prim::If").run(loaded_model.graph) + FileCheck().check_count("aten::div", 2).run(loaded_model.graph) + + buffer = io.BytesIO() + torch.jit.save(loaded_model, buffer) + buffer.seek(0) + version = self._load_model_version(loaded_model) + self.assertTrue(version == 4) + loaded_model_twice = torch.jit.load(buffer) + # we check by its' code because graph variable names + # can be different every time + self.assertEqual(loaded_model.code, loaded_model_twice.code) + + @unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled") + def test_aten_full_at_4(self): + model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt" + loaded_model = torch.jit.load(model_path) + FileCheck().check_count("aten::Float", 1).run(loaded_model.graph) + FileCheck().check_count("aten::full", 2).run(loaded_model.graph) + + buffer = io.BytesIO() + torch.jit.save(loaded_model, buffer) + buffer.seek(0) + version = self._load_model_version(loaded_model) + self.assertTrue(version == 5) + loaded_model_twice = torch.jit.load(buffer) + # we check by its' code because graph variable names + # can be different every time + self.assertEqual(loaded_model.code, loaded_model_twice.code) + + @unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled") + def test_aten_full_out_at_4(self): + model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt" + loaded_model = torch.jit.load(model_path) + FileCheck().check_count("aten::full", 5).run(loaded_model.graph) + version = self._load_model_version(loaded_model) + self.assertTrue(version == 5) diff --git a/test/test_jit.py b/test/test_jit.py index ffa614ff660..5ccfcd5da85 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -68,6 +68,7 @@ from jit.test_attr import TestGetDefaultAttr # noqa: F401 from jit.test_aten_pow import TestAtenPow # noqa: F401 from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401 from jit.test_union import TestUnion # noqa: F401 +from jit.test_legacy_upgraders import TestLegacyUpgraders # noqa: F401 from jit.test_models import MnistNet from jit.test_batch_mm import TestBatchMM # noqa: F401 from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401 diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index ef01ddf2180..ecf62c4ff88 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -112,6 +112,7 @@ core_sources_common = [ "torch/csrc/jit/mobile/type_parser.cpp", "torch/csrc/jit/mobile/runtime_compatibility.cpp", "torch/csrc/jit/operator_upgraders/version_map.cpp", + "torch/csrc/jit/operator_upgraders/upgraders_guard.cpp", "torch/csrc/jit/runtime/instruction.cpp", "torch/csrc/jit/runtime/jit_exception.cpp", "torch/csrc/jit/runtime/operator.cpp", @@ -121,7 +122,6 @@ core_sources_common = [ "torch/csrc/jit/runtime/vararg_functions.cpp", "torch/csrc/jit/mobile/promoted_prim_ops.cpp", "torch/csrc/jit/mobile/prim_ops_registery.cpp", - "torch/csrc/jit/operator_upgraders/upgraders.cpp", "torch/csrc/profiler/util.cpp", ] @@ -209,6 +209,8 @@ core_sources_full_mobile_no_backend_interface = [ "torch/csrc/jit/mobile/nnc/context.cpp", "torch/csrc/jit/mobile/nnc/registry.cpp", "torch/csrc/jit/operator_upgraders/utils.cpp", + "torch/csrc/jit/operator_upgraders/upgraders_entry.cpp", + "torch/csrc/jit/operator_upgraders/upgraders.cpp", "torch/csrc/jit/passes/annotate_warns.cpp", "torch/csrc/jit/passes/bailout_graph.cpp", "torch/csrc/jit/passes/batch_mm.cpp", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index f77803c9a7c..e2994993cc8 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -293,7 +293,7 @@ def _create_function_from_trace( def _jit_is_script_object(obj: Any) -> _bool: ... def _last_executed_optimized_graph() -> Graph: ... def parse_type_comment(comment: str) -> Decl: ... -def _populate_upgraders_map(content: Dict[str, str]) -> None: ... +def _is_upgraders_enabled() -> _bool: ... def _get_upgraders_map_size() -> _int: ... def _dump_upgraders_map() -> Dict[str, str]: ... def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ... diff --git a/torch/csrc/jit/frontend/builtin_functions.cpp b/torch/csrc/jit/frontend/builtin_functions.cpp index dca669f2f23..5a2c4ba25dd 100644 --- a/torch/csrc/jit/frontend/builtin_functions.cpp +++ b/torch/csrc/jit/frontend/builtin_functions.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -85,6 +86,7 @@ def __contains__(self: str, key: str): return self.find(key, 0, len(self)) != -1 )SCRIPT"; +#if !ENABLE_UPGRADERS // Implementations of historic symbol behaviors are defined here // See note [Versioned Symbols] @@ -158,7 +160,6 @@ def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None, pin_memory:Optional[bool]=None) -> Tensor: if dtype is None: fill_value = float(fill_value) - return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory) )SCRIPT"; @@ -168,6 +169,7 @@ auto full_out = R"SCRIPT( def full_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor: return torch.full(size, fill_value, out=out) )SCRIPT"; +#endif struct BuiltinFunctionRegistry { const std::vector& getAllBuiltinFunctionsFor(Symbol name) { @@ -237,6 +239,7 @@ struct BuiltinFunctionRegistry { loadSource(aten_ops, "aten"); loadSource(aten_ops_additional, "aten"); +#if !ENABLE_UPGRADERS // Loads functions implementing historic behavior, see note [Versioned // Symbols] // Note: these functions go into the "upgraders" namespace @@ -249,6 +252,7 @@ struct BuiltinFunctionRegistry { loadSource(div__scalar, "upgraders"); loadSource(full, "upgraders"); loadSource(full_out, "upgraders"); +#endif // These are under `prim` instead of `aten` since they exist to bind certain // tensor property getters to correpsonding methods diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 9cc51bc6155..20cab7c7499 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -22,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -656,6 +658,13 @@ struct to_ir { } method.setSchema(emitDef(def, self, graph->block())); +#if ENABLE_UPGRADERS + // At this point, we might have received a graph that is compiled with + // old operator schemas that might not exist in the system anymore. + // Therefore, we replace such ops with its' valid upgrader. + ReplaceOldOperatorsWithUpgraders(graph); +#endif + // NB ORDERING: SSA conversion has to occur before // lifting of closures and forks, this way closures are converted // to SSA while part of their original graph, and closures are ready to diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index 5219791d140..7081a433365 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -469,7 +470,7 @@ static c10::optional tryMatchSchema( } // construct the full name of the schema for easier look up - auto schema_name = schema.operator_name().name + "." + schema.overload_name(); + auto schema_name = getFullSchemaName(schema); return MatchedSchema{ std::move(positional_inputs), @@ -619,6 +620,13 @@ static Value* emitBuiltinNode( return packOutputs(graph, n->outputs(), matched_schema.return_field_names); } +std::string getFullSchemaName(const ::c10::FunctionSchema& schema) { + if (schema.overload_name() != "") { + return schema.operator_name().name + "." + schema.overload_name(); + } + return schema.operator_name().name; +} + // Search for operators matching the provided symbol name and input types. // If one is found, emit a node to the graph for that operator. Value* emitBuiltinCall( @@ -631,8 +639,12 @@ Value* emitBuiltinCall( const auto& variants = getAllOperatorsFor(name); const auto& builtin_functions = getAllBuiltinFunctionsFor(name); +#if ENABLE_UPGRADERS // first let's set the graph's version auto graph_version = graph.get_op_version(); +#else + c10::optional graph_version = c10::nullopt; +#endif std::stringstream failure_messages; std::vector schemas; @@ -643,9 +655,7 @@ Value* emitBuiltinCall( schemas.reserve(variants.size()); for (const std::shared_ptr& op : variants) { bool found_upgrader = false; - auto op_overload_name = op->schema().overload_name(); - auto op_name = op->schema().operator_name().name + - ((op_overload_name != "") ? "." + op_overload_name : ""); + auto op_name = getFullSchemaName(op->schema()); if (graph_version.has_value()) { auto version_entry = get_operator_version_map().find(op_name); if (version_entry != get_operator_version_map().end()) { diff --git a/torch/csrc/jit/frontend/schema_matching.h b/torch/csrc/jit/frontend/schema_matching.h index f969c3e212a..71f03e398f4 100644 --- a/torch/csrc/jit/frontend/schema_matching.h +++ b/torch/csrc/jit/frontend/schema_matching.h @@ -41,6 +41,8 @@ TORCH_API bool convertibleToList( const TypePtr& type, const TypePtr& list_type_); +TORCH_API std::string getFullSchemaName(const ::c10::FunctionSchema& schema); + TORCH_API Value* emitBuiltinCall( const SourceRange& loc, Graph& graph, diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index 91a2d3e4fbf..f6a3f72a59d 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -319,12 +320,14 @@ struct TORCH_API BuiltinModule : public SugaredValue { } auto sym = Symbol::fromQualString(name + "::" + field); +#if !ENABLE_UPGRADERS if (version.has_value()) { // Possibly replaces symbol with another that implements its // historic behavior. // See note [Versioned Symbols] sym = get_symbol_for_version(sym, *version); } +#endif return std::make_shared(sym, c10::nullopt); } diff --git a/torch/csrc/jit/frontend/versioned_symbols.cpp b/torch/csrc/jit/frontend/versioned_symbols.cpp index d7ee0d3393a..564bdf15c4b 100644 --- a/torch/csrc/jit/frontend/versioned_symbols.cpp +++ b/torch/csrc/jit/frontend/versioned_symbols.cpp @@ -1,12 +1,13 @@ #include +#include #include #include namespace torch { namespace jit { - +#if !ENABLE_UPGRADERS // Note [Versioned Symbols] // When the schema or behavior of a symbol changes, serialized Torchscript // programs using that symbol are likely to break. To prevent those breaks, @@ -105,6 +106,7 @@ uint64_t get_min_version_for_kind(const NodeKind& kind) { return it->second; } +#endif } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/frontend/versioned_symbols.h b/torch/csrc/jit/frontend/versioned_symbols.h index 1c59708afd7..cc9f8b354b8 100644 --- a/torch/csrc/jit/frontend/versioned_symbols.h +++ b/torch/csrc/jit/frontend/versioned_symbols.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -7,7 +8,7 @@ namespace torch { namespace jit { - +#if !ENABLE_UPGRADERS // Maps the given symbol into an implementation of its behavior at the // given version. // See note [Versioned Symbols] @@ -17,6 +18,6 @@ get_symbol_for_version(const Symbol name, const uint64_t version); // Maps the given kind to the minimum version that supports it. // See note [Dynamic Versions and torch.jit.save vs. torch.save] TORCH_API uint64_t get_min_version_for_kind(const NodeKind& kind); - +#endif } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/operator_upgraders/upgraders.cpp b/torch/csrc/jit/operator_upgraders/upgraders.cpp index 67c2957692e..3acbad8bbbe 100644 --- a/torch/csrc/jit/operator_upgraders/upgraders.cpp +++ b/torch/csrc/jit/operator_upgraders/upgraders.cpp @@ -1,4 +1,7 @@ #include + +#include +#include #include #include #include @@ -9,12 +12,13 @@ namespace jit { static UpgradersMap upgradersMap; void UpgradersMap::set_content( - std::unordered_map&& content) { + std::unordered_map>&& content) { // make sure we populate the map only once std::lock_guard _(lock); if (isPopulated) { return; } + content_ = std::move(content); isPopulated = true; } @@ -24,7 +28,12 @@ int UpgradersMap::count() { return content_.size(); } -const std::unordered_map& UpgradersMap:: +bool UpgradersMap::is_populated() { + std::lock_guard _(lock); + return isPopulated; +} + +const std::unordered_map>& UpgradersMap:: get_content() { std::lock_guard _(lock); return content_; @@ -34,7 +43,9 @@ void UpgradersMap::test_only_set_content( const std::unordered_map& content) { std::lock_guard _(lock); for (const auto& entry : content) { - content_.insert(entry); + auto graph = std::make_shared(); + torch::jit::parseIR(entry.second, graph.get()); + content_.insert(std::make_pair(entry.first, graph)); } } void UpgradersMap::test_only_remove_content( @@ -46,7 +57,7 @@ void UpgradersMap::test_only_remove_content( } void populate_upgraders_map( - std::unordered_map&& content) { + std::unordered_map>&& content) { upgradersMap.set_content(std::move(content)); } @@ -54,7 +65,12 @@ int get_upgraders_map_size() { return upgradersMap.count(); } -const std::unordered_map& dump_upgraders_map() { +bool is_upgraders_map_populated() { + return upgradersMap.is_populated(); +} + +const std::unordered_map>& +dump_upgraders_map() { return upgradersMap.get_content(); } diff --git a/torch/csrc/jit/operator_upgraders/upgraders.h b/torch/csrc/jit/operator_upgraders/upgraders.h index d96e48793e9..a6c4b81b15f 100644 --- a/torch/csrc/jit/operator_upgraders/upgraders.h +++ b/torch/csrc/jit/operator_upgraders/upgraders.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include #include @@ -10,9 +11,11 @@ namespace jit { class UpgradersMap { public: - void set_content(std::unordered_map&& content); + void set_content( + std::unordered_map>&& content); int count(); - const std::unordered_map& get_content(); + const std::unordered_map>& get_content(); + bool is_populated(); // THESE METHODS ARE ONLY USED FOR TESTING PURPOSES void test_only_set_content( const std::unordered_map& content); @@ -20,17 +23,19 @@ class UpgradersMap { const std::unordered_map& content); private: - std::unordered_map content_; + std::unordered_map> content_; std::mutex lock; bool isPopulated = false; }; TORCH_API void populate_upgraders_map( - std::unordered_map&& content); + std::unordered_map>&& content); TORCH_API int get_upgraders_map_size(); -TORCH_API const std::unordered_map& +TORCH_API bool is_upgraders_map_populated(); + +TORCH_API const std::unordered_map>& dump_upgraders_map(); // THESE TWO METHODS BELOW ARE ONLY USED FOR TESTING diff --git a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp new file mode 100644 index 00000000000..b4efbdbe032 --- /dev/null +++ b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp @@ -0,0 +1,78 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +static std::unordered_map kUpgradersEntryMap( + {{"div_Tensor_0_3", R"SCRIPT( +def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor: + if (self.is_floating_point() or other.is_floating_point()): + return self.true_divide(other) + return self.divide(other, rounding_mode='trunc') +)SCRIPT"}, + {"div_Scalar_0_3", R"SCRIPT( +def div_Scalar_0_3(self: Tensor, other: number) -> Tensor: + if (self.is_floating_point() or isinstance(other, float)): + return self.true_divide(other) + return self.divide(other, rounding_mode='trunc') +)SCRIPT"}, + {"div_out_0_3", R"SCRIPT( +def div_out_0_3(self: Tensor, other: Tensor, *, out: Tensor) -> Tensor: + if (self.is_floating_point() or other.is_floating_point() or out.is_floating_point()): + return self.true_divide(other, out=out) + return self.divide(other, rounding_mode='trunc', out=out) +)SCRIPT"}, + {"div__Tensor_0_3", R"SCRIPT( +def div__Tensor_0_3(self: Tensor, other: Tensor) -> Tensor: + if (self.is_floating_point() or other.is_floating_point()): + return self.true_divide_(other) + return self.divide_(other, rounding_mode='trunc') +)SCRIPT"}, + {"div__Scalar_0_3", R"SCRIPT( +def div__Scalar_0_3(self: Tensor, other: number) -> Tensor: + if (self.is_floating_point() or isinstance(other, float)): + return self.true_divide_(other) + return self.divide_(other, rounding_mode='trunc') +)SCRIPT"}, + {"full_0_4", R"SCRIPT( +def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None, + layout:Optional[int]=None, device:Optional[Device]=None, + pin_memory:Optional[bool]=None) -> Tensor: + if dtype is None: + fill_value = float(fill_value) + return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory) +)SCRIPT"}, + {"full_out_0_4", R"SCRIPT( +def full_out_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor: + return torch.full(size, fill_value, out=out) +)SCRIPT"}}); + +using UpgraderMap = std::unordered_map>; +void populate_upgraders_graph_map() { + if (!is_upgraders_map_populated()) { + UpgraderMap populate_content; + for (const auto& entry : kUpgradersEntryMap) { + auto cu = std::make_shared(); + cu->define(c10::nullopt, entry.second, nativeResolver(), nullptr); + Function& jitFunc = cu->get_function(entry.first); + GraphFunction& graphFunction = toGraphFunction(jitFunc); + populate_content.insert( + std::make_pair(entry.first, graphFunction.graph())); + } + + populate_upgraders_map(std::forward(populate_content)); + } +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/operator_upgraders/upgraders_entry.h b/torch/csrc/jit/operator_upgraders/upgraders_entry.h new file mode 100644 index 00000000000..bb3d7f5e355 --- /dev/null +++ b/torch/csrc/jit/operator_upgraders/upgraders_entry.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +TORCH_API void populate_upgraders_graph_map(); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/operator_upgraders/upgraders_guard.cpp b/torch/csrc/jit/operator_upgraders/upgraders_guard.cpp new file mode 100644 index 00000000000..02364c5eee8 --- /dev/null +++ b/torch/csrc/jit/operator_upgraders/upgraders_guard.cpp @@ -0,0 +1,12 @@ +#include +#include + +namespace torch { +namespace jit { + +bool is_upgraders_enabled() { + return ENABLE_UPGRADERS; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/operator_upgraders/upgraders_guard.h b/torch/csrc/jit/operator_upgraders/upgraders_guard.h new file mode 100644 index 00000000000..9907f8238c0 --- /dev/null +++ b/torch/csrc/jit/operator_upgraders/upgraders_guard.h @@ -0,0 +1,10 @@ +#pragma once +#include + +namespace torch { +namespace jit { + +TORCH_API bool is_upgraders_enabled(); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/operator_upgraders/version_map.cpp b/torch/csrc/jit/operator_upgraders/version_map.cpp index 5f07b82f0e2..94e996f9e0d 100644 --- a/torch/csrc/jit/operator_upgraders/version_map.cpp +++ b/torch/csrc/jit/operator_upgraders/version_map.cpp @@ -31,11 +31,11 @@ static std::unordered_map> operatorVersi {"aten::div_.Tensor", {{4, "div__Tensor_0_3", - "aten::div_.Tensor(Tensor(a!), Tensor other) -> Tensor(a!)"}}}, + "aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"}}}, {"aten::div_.Scalar", {{4, "div__Scalar_0_3", - "aten::div_.Scalar(Tensor(a!), Tensor other) -> Tensor(a!)"}}}, + "aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"}}}, {"aten::full", {{5, "full_0_4", diff --git a/torch/csrc/jit/passes/replacement_of_old_operators.cpp b/torch/csrc/jit/passes/replacement_of_old_operators.cpp index 5950a06a3f8..5cf8e43b195 100644 --- a/torch/csrc/jit/passes/replacement_of_old_operators.cpp +++ b/torch/csrc/jit/passes/replacement_of_old_operators.cpp @@ -15,29 +15,6 @@ namespace torch { namespace jit { -static std::unordered_map> upgraderCache; - -std::shared_ptr getUpgraderGraph(const std::string& upgrader_name) { - auto it = upgraderCache.find(upgrader_name); - if (it != upgraderCache.end()) { - return it->second; - } - - auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name); - TORCH_INTERNAL_ASSERT( - upgrader_graph_entry != dump_upgraders_map().end(), - "Corresponding upgrader graph for ", - upgrader_name, - " must exist.", - " This upgrader" - " might be deprecated."); - - auto upgrader_graph = std::make_shared(); - parseIR(upgrader_graph_entry->second, upgrader_graph.get()); - upgraderCache[upgrader_name] = upgrader_graph; - return upgrader_graph; -} - struct OldOpsReplacerWithUpgraders { OldOpsReplacerWithUpgraders(std::shared_ptr graph) : graph_(std::move(graph)) {} @@ -52,9 +29,7 @@ struct OldOpsReplacerWithUpgraders { Node* node = graph_it.next(); while (node) { if (auto schema = node->maybeSchema()) { - auto schema_name = schema->name() + - (schema->overload_name() != "" ? "." + schema->overload_name() - : ""); + auto schema_name = getFullSchemaName(*schema); // this implies there was a version bump because of this operator auto version_entry = get_operator_version_map().find(schema_name); if (version_entry != get_operator_version_map().end()) { @@ -74,8 +49,16 @@ struct OldOpsReplacerWithUpgraders { } auto upgrader_entry_val = upgrader_entry.value(); auto upgrader_name = upgrader_entry_val.upgrader_name; - auto upgrader_graph = getUpgraderGraph(upgrader_name); + auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name); + TORCH_INTERNAL_ASSERT( + upgrader_graph_entry != dump_upgraders_map().end(), + "Corresponding upgrader graph for ", + upgrader_name, + " must exist.", + " This upgrader" + " might be deprecated."); + auto upgrader_graph = upgrader_graph_entry->second; // inline the upgrader function body WithInsertPoint guard(node); auto new_outputs = insertGraph( diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 4f7e5541c15..4bb9d050830 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -1731,7 +1732,8 @@ void initJitScriptBindings(PyObject* module) { return Decl(p.parseTypeComment()); }); - m.def("_populate_upgraders_map", &populate_upgraders_map); + m.def("_is_upgraders_enabled", &is_upgraders_enabled); + m.def("_get_upgraders_map_size", &get_upgraders_map_size); m.def("_dump_upgraders_map", &dump_upgraders_map); diff --git a/torch/csrc/jit/serialization/import.cpp b/torch/csrc/jit/serialization/import.cpp index 86aa6e3909e..6a999bd3864 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -10,6 +10,7 @@ #endif #include #include +#include #include #include #include @@ -20,6 +21,7 @@ #include #include #include +#include #include #include @@ -239,6 +241,10 @@ graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_ze Module ScriptModuleDeserializer::deserialize( c10::optional device, ExtraFilesMap& extra_files) { + // we populate the upgraders map before any load starts +#if ENABLE_UPGRADERS + populate_upgraders_graph_map(); +#endif C10_LOG_API_USAGE_ONCE("torch.script.load"); device_ = device; // Load extra files. diff --git a/torch/csrc/jit/serialization/import_source.cpp b/torch/csrc/jit/serialization/import_source.cpp index 03eff0fc0c3..d1fb898730e 100644 --- a/torch/csrc/jit/serialization/import_source.cpp +++ b/torch/csrc/jit/serialization/import_source.cpp @@ -577,7 +577,9 @@ void SourceImporterImpl::importClass( /*propResolvers=*/{}, methods, method_resolvers, - &self); + &self, + /*shouldMangle=*/false, + /*operator_set_version=*/version_); cu_->define_hooks( qualified_classname, hooks, diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index b994e1ab998..f3e5f6e2ab9 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include #include @@ -741,9 +743,22 @@ struct PythonPrintImpl { } } - void checkVersion(const Node* const node) { + void checkVersion(Node* node) { +#if ENABLE_UPGRADERS + if (auto schema = node->maybeSchema()) { + auto schema_name = getFullSchemaName(*schema); + auto version_entry = get_operator_version_map().find(schema_name); + if (version_entry != get_operator_version_map().end()) { + const auto& entry = version_entry->second; + // TODO (tugsuu) move this calculation into a seperate step. + min_version_ = std::max( + min_version_, uint64_t(entry[entry.size() - 1].bumped_at_version)); + } + } +#else min_version_ = std::max(min_version_, get_min_version_for_kind(node->kind())); +#endif } void printNode(Node* node, bool print_const) { @@ -1594,7 +1609,11 @@ struct PythonPrintImpl { bool enforce_importable_; // The least version that supports all printed ops +#if ENABLE_UPGRADERS + uint64_t min_version_ = caffe2::serialize::kMinSupportedFileFormatVersion; +#else uint64_t min_version_ = 0; +#endif }; PythonPrint::PythonPrint( diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py index 99156905ca7..f2c32c3a19b 100644 --- a/torch/jit/_serialization.py +++ b/torch/jit/_serialization.py @@ -147,11 +147,6 @@ def load(f, map_location=None, _extra_files=None): os.remove("scriptmodule.pt") """ - # TODO (tugsuu) Putting this on top level creates - # circular import issue. We need to fix this later - from torch.jit.operator_upgraders import populate_upgraders_map - populate_upgraders_map() - if isinstance(f, string_classes): if not os.path.exists(f): # type: ignore[type-var] raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe] diff --git a/torch/jit/operator_upgraders.py b/torch/jit/operator_upgraders.py index 7d2bcfd16f2..fa6d96b9a59 100644 --- a/torch/jit/operator_upgraders.py +++ b/torch/jit/operator_upgraders.py @@ -116,12 +116,5 @@ def generate_bytecode() -> List: yaml_content.append(entry) return yaml_content -def populate_upgraders_map(): - upgrader_set = collect_available_upgraders() - content = {} - for upgrader_name in upgrader_set: - content[upgrader_name] = str(globals()[upgrader_name].graph) - torch._C._populate_upgraders_map(content) - if __name__ == "__main__": raise RuntimeError("This file is not meant to be run directly")