Bump version number to 7 and compile old operators with old schema (#68358)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68358

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33433730

Pulled By: tugsbayasgalan

fbshipit-source-id: 202c58365bae13195d3545cefcb0da9162b02151
Tugsbayasgalan (Tugsuu) Manlaibaatar authored this commit on 2022-01-05 23:55:49 -08:00; committed by Facebook GitHub Bot
parent 8bdbe94344
commit b0fdca8855
32 changed files with 993 additions and 281 deletions
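The change is gated behind the ENABLE_UPGRADERS build flag defined in caffe2/serialize/versions.h (second diff below), and the PR adds a torch._C._is_upgraders_enabled() binding so Python code can branch on it. A minimal probe, sketched against a build of this revision:

import torch

# _is_upgraders_enabled() is the binding added in this PR (see the .pyi diff
# below). True means the build serializes with the new global operator
# versioning and applies upgraders on load; False means the legacy dynamic
# versioning approach is active.
if torch._C._is_upgraders_enabled():
    print("operator upgraders enabled")
else:
    print("legacy dynamic versioning active")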

View File

@@ -164,7 +164,13 @@ class TORCH_API PyTorchStreamWriter final {
std::string padding_;
std::ofstream file_stream_;
std::function<size_t(const void*, size_t)> writer_func_;
// This number will be updated when the model has operators
// that have valid upgraders.
#if ENABLE_UPGRADERS
uint64_t version_ = kMinProducedFileFormatVersion;
#else
uint64_t version_ = kProducedFileFormatVersion;
#endif
bool finalized_ = false;
bool err_seen_ = false;
friend size_t ostream_write_func(

View File

@@ -1,12 +1,21 @@
#pragma once
#include <cstdint>
namespace caffe2 {
namespace serialize {
// Flag that controls whether we want to enable upgraders
// on the server side. When this flag is set to false,
// we switch back to the old dynamic versioning approach.
#define ENABLE_UPGRADERS false
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
#if ENABLE_UPGRADERS
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x7L;
#else
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
#endif
// Versions (i.e. why was the version number bumped?)
@@ -47,7 +56,23 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
// 5. (Dynamic) Stops torch.full inferring a floating point dtype
// when given bool or integer fill values.
// 6. Write version string to `./data/version` instead of `version`.
#if ENABLE_UPGRADERS
// This is bumped from 3 to 7 due to a different interpretation of
// what the file format version is. Whenever a new upgrader is
// introduced, this number should be bumped.
// 1. aten::div is changed at version 4
// 2. aten::full is changed at version 5
// 3. torch.package uses version 6
constexpr uint64_t kProducedFileFormatVersion = 0x7L;
#else
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
#endif
// Absolute minimum version at which we will write packages. This
// means that every package from now on will always be
// greater than or equal to this number.
constexpr uint64_t kMinProducedFileFormatVersion = 0x3L;
// The version we write when the archive contains bytecode.
// It must be higher than or equal to kProducedFileFormatVersion.
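These constants can be observed from Python by saving a module that calls a versioned operator and reading the version record back, the same way the _load_model_version helper added to test_upgraders.py below does. A minimal sketch, assuming a build with ENABLE_UPGRADERS on:

import io
import zipfile

import torch

class M(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return torch.div(x, y)

buffer = io.BytesIO()
torch.jit.save(torch.jit.script(M()), buffer)
buffer.seek(0)

# When saving to a buffer the archive root is named 'archive'; the produced
# file format version lives in its 'version' record.
version = int(zipfile.ZipFile(buffer).read("archive/version").decode("utf-8"))

# kMinProducedFileFormatVersion is 3 and aten::div changed at version 4, so
# a module calling torch.div should serialize at version 4 here.
assert 3 <= version <= 7
print(version)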

View File

@@ -42,6 +42,7 @@ set(JIT_TEST_SRCS
${JIT_TEST_ROOT}/test_alias_analysis.cpp
${JIT_TEST_ROOT}/test_argument_spec.cpp
${JIT_TEST_ROOT}/test_autodiff.cpp
${JIT_TEST_ROOT}/test_load_upgraders.cpp
${JIT_TEST_ROOT}/test_op_replacement.cpp
${JIT_TEST_ROOT}/test_upgrader_utils.cpp
${JIT_TEST_ROOT}/test_backend.cpp

View File

@@ -0,0 +1,45 @@
#include <caffe2/serialize/versions.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/serialization/import.h>
#include <test/cpp/jit/test_utils.h>
namespace torch {
namespace jit {
#if ENABLE_UPGRADERS
// Basic tests to check if C++ torch::jit::load
// can load the upgraders fine
// TODO (tugsuu) add more tests
TEST(UpgraderLoad, CanPopulateUpgradersGraph) {
Module m("m");
m.define(R"(
def forward(self, x: Tensor):
b = 5
return torch.div(x, b)
)");
std::stringstream ms;
m.save(ms);
auto loaded_m = torch::jit::load(ms);
auto version_map = get_operator_version_map();
auto upgraders = dump_upgraders_map();
for (const auto& entry : version_map) {
auto list_of_upgraders_for_op = entry.second;
for (const auto& upgrader_entry : list_of_upgraders_for_op) {
EXPECT_TRUE(
upgraders.find(upgrader_entry.upgrader_name) != upgraders.end());
}
}
auto test_graph = loaded_m.get_method("forward").graph();
// should have saved with version 4, so it is still up to date
testing::FileCheck().check_count("aten::div", 1, true)->run(*test_graph);
}
#endif
} // namespace jit
} // namespace torch

View File

@@ -0,0 +1,553 @@
# Owner(s): ["oncall: jit"]
from itertools import product as product
import io
import os
import random
import sys
import unittest
import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Legacy test cases for dynamic operator versioning
class TestLegacyUpgraders(JitTestCase):
reason: str = "Skipped due to new global operator versioning for operators"
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_symbols(self):
"""
Tests Torchscript symbol versioning. See note [Versioned Symbols].
This test uses an undocumented, test-only function
torch._test_serialization_subcmul.
This function is implemented as (a - alpha * b) with a default value
of 1 for alpha. In file format version 2, however, it was implemented
as (b - alpha * a) with a default value of 2 for alpha.
This test verifies a module serialized with file format version 2
exhibits the old behavior, and that the same module newly serialized
exhibits the current behavior.
"""
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a, b, alpha: float):
no_alpha = torch._test_serialization_subcmul(a, b)
with_alpha = torch._test_serialization_subcmul(a, b, alpha)
return no_alpha, with_alpha
def historic_subcmul(a, b, alpha=2):
return b - alpha * a
def current_subcmul(a, b, alpha=1):
return a - alpha * b
# Loads and verifies the historic behavior of the module
# that was serialized with version 2
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
a = torch.randn((5,))
b = torch.randn((5,))
alpha = random.random()
args = (a, b, alpha)
no_alpha_v2, with_alpha_v2 = module_v2(*args)
self.assertEqual(no_alpha_v2, historic_subcmul(a, b))
self.assertEqual(with_alpha_v2, historic_subcmul(*args))
# Scripts, saves, loads and verifies the current behavior of the module
scripted_module = torch.jit.script(MyModule())
buffer = io.BytesIO()
torch.jit.save(scripted_module, buffer)
buffer.seek(0)
module_current = torch.jit.load(buffer)
no_alpha_current, with_alpha_current = module_current(*args)
self.assertEqual(no_alpha_current, current_subcmul(a, b))
self.assertEqual(with_alpha_current, current_subcmul(*args))
# Helper that returns the module after saving and loading
def _save_load_module(self, m):
scripted_module = torch.jit.script(m())
buffer = io.BytesIO()
torch.jit.save(scripted_module, buffer)
buffer.seek(0)
return torch.jit.load(buffer)
# Helper which returns the result of a function or the exception the
# function threw.
def _try_fn(self, fn, *args, **kwargs):
try:
return fn(*args, **kwargs)
except Exception as e:
return e
def _verify_no(self, kind, m):
self._verify_count(kind, m, 0)
def _verify_count(self, kind, m, count):
node_count = sum(str(n).count(kind) for n in m.graph.nodes())
self.assertEqual(node_count, count)
"""
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
to call either aten::true_divide(_), if an input is a float type,
or truncated aten::divide(_) otherwise.
NOTE: currently compares against current div behavior, too, since
div behavior has not yet been updated.
"""
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_tensor(self):
def historic_div(self, other):
if self.is_floating_point() or other.is_floating_point():
return self.true_divide(other)
return self.divide(other, rounding_mode='trunc')
# Tensor x Tensor
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a, b):
result_0 = a / b
result_1 = torch.div(a, b)
result_2 = a.div(b)
return result_0, result_1, result_2
# Loads historic module
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 6) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
self._verify_count("aten::div", current_module, 3)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = torch.tensor((val_b,))
def _helper(m, fn):
m_results = self._try_fn(m, a, b)
fn_result = self._try_fn(fn, a, b)
if isinstance(m_results, Exception):
self.assertTrue(isinstance(fn_result, Exception))
else:
for result in m_results:
self.assertEqual(result, fn_result)
_helper(v3_module, historic_div)
_helper(current_module, torch.div)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_tensor_inplace(self):
def historic_div_(self, other):
if self.is_floating_point() or other.is_floating_point():
return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc')
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a, b):
a /= b
return a
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 2) # true_divide and divide both alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
self._verify_count("aten::div", current_module, 1)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = torch.tensor((val_b,))
def _helper(m, fn):
fn_result = self._try_fn(fn, a.clone(), b)
m_result = self._try_fn(m, a, b)
if isinstance(m_result, Exception):
self.assertTrue(isinstance(fn_result, Exception))
else:
self.assertEqual(m_result, fn_result)
self.assertEqual(m_result, a)
_helper(v3_module, historic_div_)
# Recreates a since it was modified in place
a = torch.tensor((val_a,))
_helper(current_module, torch.Tensor.div_)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_tensor_out(self):
def historic_div_out(self, other, out):
if self.is_floating_point() or other.is_floating_point() or out.is_floating_point():
return torch.true_divide(self, other, out=out)
return torch.divide(self, other, out=out, rounding_mode='trunc')
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a, b, out):
return a.div(b, out=out)
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
self._verify_count("aten::div", current_module, 1)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = torch.tensor((val_b,))
for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)):
def _helper(m, fn):
fn_result = None
if fn is torch.div:
fn_result = self._try_fn(fn, a, b, out=out.clone())
else:
fn_result = self._try_fn(fn, a, b, out.clone())
m_result = self._try_fn(m, a, b, out)
if isinstance(m_result, Exception):
self.assertTrue(isinstance(fn_result, Exception))
else:
self.assertEqual(m_result, fn_result)
self.assertEqual(m_result, out)
_helper(v3_module, historic_div_out)
_helper(current_module, torch.div)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_scalar(self):
def historic_div_scalar_float(self, other: float):
return torch.true_divide(self, other)
def historic_div_scalar_int(self, other: int):
if self.is_floating_point():
return torch.true_divide(self, other)
return torch.divide(self, other, rounding_mode='trunc')
class MyModuleFloat(torch.nn.Module):
def __init__(self):
super(MyModuleFloat, self).__init__()
def forward(self, a, b: float):
return a / b
class MyModuleInt(torch.nn.Module):
def __init__(self):
super(MyModuleInt, self).__init__()
def forward(self, a, b: int):
return a / b
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_int_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
for m in (v3_module_float, v3_module_int):
self._verify_count("aten::div", m, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
for m in (current_module_float, current_module_int):
self._verify_count("aten::div", m, 1)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = val_b
def _helper(m, fn):
m_result = self._try_fn(m, a, b)
fn_result = self._try_fn(fn, a, b)
if isinstance(m_result, Exception):
self.assertTrue(isinstance(fn_result, Exception))
else:
self.assertEqual(m_result, fn_result)
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float)
_helper(current_module_float, torch.div)
else:
_helper(v3_module_int, historic_div_scalar_int)
_helper(current_module_int, torch.div)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_scalar_reciprocal(self):
def historic_div_scalar_float_reciprocal(self, other: float):
return other / self
def historic_div_scalar_int_reciprocal(self, other: int):
if self.is_floating_point():
return other / self
return torch.divide(other, self, rounding_mode='trunc')
class MyModuleFloat(torch.nn.Module):
def __init__(self):
super(MyModuleFloat, self).__init__()
def forward(self, a, b: float):
return b / a
class MyModuleInt(torch.nn.Module):
def __init__(self):
super(MyModuleInt, self).__init__()
def forward(self, a, b: int):
return b / a
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
# NOTE: number / tensor is rewritten to torch.reciprocal(a) * b
# so true_divide and floor_divide do not appear in their graphs
for m in (v3_module_float, v3_module_int):
self._verify_no("aten::div", m)
self._verify_no("aten::true_divide", m)
self._verify_no("aten::floor_divide", m)
self._verify_count("aten::reciprocal", m, 1)
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = val_b
def _helper(m, fn):
m_result = self._try_fn(m, a, b)
fn_result = None
# Reverses argument order for torch.div
if fn is torch.div:
fn_result = self._try_fn(torch.div, b, a)
else:
fn_result = self._try_fn(fn, a, b)
if isinstance(m_result, Exception):
self.assertTrue(isinstance(fn_result, Exception))
elif fn is torch.div or a.is_floating_point():
self.assertEqual(m_result, fn_result)
else:
# Skip when fn is not torch.div and a is integral because
# historic_div_scalar_int performs floored division
pass
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float_reciprocal)
_helper(current_module_float, torch.div)
else:
_helper(v3_module_int, historic_div_scalar_int_reciprocal)
_helper(current_module_int, torch.div)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_scalar_inplace(self):
def historic_div_scalar_float_inplace(self, other: float):
return self.true_divide_(other)
def historic_div_scalar_int_inplace(self, other: int):
if self.is_floating_point():
return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc')
class MyModuleFloat(torch.nn.Module):
def __init__(self):
super(MyModuleFloat, self).__init__()
def forward(self, a, b: float):
a /= b
return a
class MyModuleInt(torch.nn.Module):
def __init__(self):
super(MyModuleInt, self).__init__()
def forward(self, a, b: int):
a /= b
return a
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
for m in (v3_module_float, v3_module_int):
self._verify_count("aten::div_", m, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
for m in (current_module_float, current_module_int):
self._verify_count("aten::div", m, 1)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = val_b
def _helper(m, fn):
m_result = self._try_fn(m, a, b)
fn_result = self._try_fn(fn, a, b)
if isinstance(m_result, Exception):
self.assertTrue(isinstance(fn_result, Exception))
else:
self.assertEqual(m_result, fn_result)
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float_inplace)
_helper(current_module_float, torch.Tensor.div_)
else:
_helper(v3_module_int, historic_div_scalar_int_inplace)
_helper(current_module_int, torch.Tensor.div_)
# NOTE: Scalar division was already true division in op version 3,
# so this test verifies the behavior is unchanged.
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_scalar_scalar(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a: float, b: int, c: float, d: int):
result_0 = a / b
result_1 = a / c
result_2 = b / c
result_3 = b / d
return (result_0, result_1, result_2, result_3)
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 4)
current_module = self._save_load_module(MyModule)
self._verify_count("aten::div", current_module, 4)
def _helper(m, fn):
vals = (5., 3, 2., 7)
m_result = m(*vals)
fn_result = fn(*vals)
for mr, hr in zip(m_result, fn_result):
self.assertEqual(mr, hr)
_helper(v3_module, current_module)
# NOTE: the JIT was incapable of handling boolean fill values when
# PyTorch produced file format versions 0-4
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_full_integer_value(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, int_fill: int):
size = torch.Size([2, 2])
a = torch.full(size, int_fill)
b = torch.full(size, 1)
return (a, b)
try:
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::full", v4_module, 2)
current_module = self._save_load_module(MyModule)
self._verify_count("aten::full", current_module, 2)
# Verifies historic integer type inference is float
# NOTE: only verifies floating point, not exact dtype, due to
# https://github.com/pytorch/pytorch/issues/40470
results = v4_module(2)
for result in results:
self.assertTrue(result.is_floating_point())
# Verifies values are correct
a, b = results
self.assertTrue((a == 2.).all())
self.assertTrue((b == 1.).all())
# Tests that torch.full behavior which is the same from prior versions
# to version 5 is preserved.
# NOTE: while torch.full in eager PyTorch accepts a requires_grad argument,
# it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_full_preserved(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, float_fill: float):
size = (2, 2)
a = torch.full(size, 1.)
b = torch.full(size, float_fill)
c = torch.full(size, float_fill, dtype=torch.long)
out = torch.empty(size, dtype=torch.long)
d = torch.full(size, float_fill, out=out)
e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None,
layout=torch.strided, device='cpu')
return (a, b, c, d, e)
try:
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::full", v4_module, 5)
current_module = self._save_load_module(MyModule)
self._verify_count("aten::full", current_module, 5)
self.assertEqual(v4_module(2.), current_module(2.))

View File

@@ -24,23 +24,6 @@ if __name__ == "__main__":
)
class TestSaveLoad(JitTestCase):
def test_versioned_symbols_reserialization(self):
"""
Tests that loading and saving serialized Torchscript with a versioned
symbol won't persist the original function and will inline the
versioned builtin.
"""
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
buffer = io.BytesIO()
torch.jit.save(module_v2, buffer)
buffer.seek(0)
module_reserialized = torch.jit.load(buffer)
subcmul_nodes = sum("subcmul" in n.kind() for
n in module_reserialized.graph.nodes())
self.assertEqual(subcmul_nodes, 0)
def test_different_modules(self):
"""
Exercise the situation where we have the same qualified name

View File

@@ -3,7 +3,6 @@
from itertools import product as product
import io
import os
import random
import sys
import hypothesis.strategies as st
from hypothesis import example, settings, given
@@ -24,55 +23,6 @@ if __name__ == "__main__":
)
class TestSaveLoadForOpVersion(JitTestCase):
def test_versioned_symbols(self):
"""
Tests Torchscript symbol versioning. See note [Versioned Symbols].
This test uses an undocumented, test-only function
torch._test_serialization_subcmul.
This function is implemented as (a - alpha * b) with a default value
of 1 for alpha. In file format version 2, however, it was implemented
as (b - alpha * a) with a default value of 2 for alpha.
This test verifies a module serialized with file format version 2
exhibits the old behavior, and that the same module newly serialized
exhibits the current behavior.
"""
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a, b, alpha: float):
no_alpha = torch._test_serialization_subcmul(a, b)
with_alpha = torch._test_serialization_subcmul(a, b, alpha)
return no_alpha, with_alpha
def historic_subcmul(a, b, alpha=2):
return b - alpha * a
def current_subcmul(a, b, alpha=1):
return a - alpha * b
# Loads and verifies the historic behavior of the module
# that was serialized with version 2
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
a = torch.randn((5,))
b = torch.randn((5,))
alpha = random.random()
args = (a, b, alpha)
no_alpha_v2, with_alpha_v2 = module_v2(*args)
self.assertEqual(no_alpha_v2, historic_subcmul(a, b))
self.assertEqual(with_alpha_v2, historic_subcmul(*args))
# Scripts, saves, loads and verifies the current behavior of the module
scripted_module = torch.jit.script(MyModule())
buffer = io.BytesIO()
torch.jit.save(scripted_module, buffer)
buffer.seek(0)
module_current = torch.jit.load(buffer)
no_alpha_current, with_alpha_current = module_current(*args)
self.assertEqual(no_alpha_current, current_subcmul(a, b))
self.assertEqual(with_alpha_current, current_subcmul(*args))
# Helper that returns the module after saving and loading
def _save_load_module(self, m):
scripted_module = torch.jit.script(m())
@@ -107,7 +57,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
to call either aten::true_divide(_), if an input is a float type,
or truncated aten::divide(_) otherwise.
NOTE: currently compares against current div behavior, too, since
div behavior has not yet been updated.
"""
@@ -137,18 +86,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
# Loads historic module
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt")
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 6) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
current_mobile_module = self._save_load_mobile_module(MyModule)
self._verify_count("aten::div", current_module, 3)
for val_a, val_b in product(sample_input, sample_input):
a = torch.tensor((val_a,))
@@ -164,9 +107,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
for result in m_results:
self.assertEqual(result, fn_result)
_helper(v3_module, historic_div)
_helper(v3_mobile_module, historic_div)
_helper(current_module, torch.div)
_helper(current_mobile_module, torch.div)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
@@ -189,18 +130,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
return a
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt")
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 2) # true_divide and divide both alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
current_mobile_module = self._save_load_mobile_module(MyModule)
self._verify_count("aten::div", current_module, 1)
for val_a, val_b in product(sample_input, sample_input):
a = torch.tensor((val_a,))
@@ -215,12 +150,10 @@ class TestSaveLoadForOpVersion(JitTestCase):
self.assertEqual(m_result, fn_result)
self.assertEqual(m_result, a)
_helper(v3_module, historic_div_)
_helper(v3_mobile_module, historic_div_)
# Recreates a since it was modified in place
a = torch.tensor((val_a,))
_helper(current_module, torch.Tensor.div_)
_helper(current_mobile_module, torch.Tensor.div_)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
@@ -242,18 +175,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
return a.div(b, out=out)
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt")
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
current_mobile_module = self._save_load_mobile_module(MyModule)
self._verify_count("aten::div", current_module, 1)
for val_a, val_b in product(sample_input, sample_input):
a = torch.tensor((val_a,))
@@ -274,8 +201,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
self.assertEqual(m_result, fn_result)
self.assertEqual(m_result, out)
_helper(v3_module, historic_div_out)
_helper(current_module, torch.div)
_helper(v3_mobile_module, historic_div_out)
_helper(current_mobile_module, torch.div)
@@ -308,8 +233,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
return a / b
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v3.pt")
v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl")
v3_mobile_module_int = _load_for_lite_interpreter(
@@ -321,14 +244,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
self._verify_count("aten::div", m, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)
for m in (current_module_float, current_module_int):
self._verify_count("aten::div", m, 1)
for val_a, val_b in product(sample_input, sample_input):
a = torch.tensor((val_a,))
b = val_b
@@ -343,13 +261,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
self.assertEqual(m_result, fn_result)
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float)
_helper(current_module_float, torch.div)
_helper(v3_mobile_module_float, current_mobile_module_float)
_helper(current_mobile_module_float, torch.div)
else:
_helper(v3_module_int, historic_div_scalar_int)
_helper(current_module_int, torch.div)
_helper(v3_mobile_module_int, historic_div_scalar_int)
_helper(current_mobile_module_int, torch.div)
@@ -382,8 +296,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
return b / a
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt")
v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl")
v3_mobile_module_int = _load_for_lite_interpreter(
@@ -391,17 +303,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
except Exception as e:
self.skipTest("Failed to load fixture!")
# NOTE: number / tensor is rewritten to torch.reciprocal(a) * b
# so true_divide and floor_divide do not appear in their graphs
for m in (v3_module_float, v3_module_int):
self._verify_no("aten::div", m)
self._verify_no("aten::true_divide", m)
self._verify_no("aten::floor_divide", m)
self._verify_count("aten::reciprocal", m, 1)
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)
@@ -428,13 +329,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
pass
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float_reciprocal)
_helper(current_module_float, torch.div)
_helper(v3_mobile_module_float, current_mobile_module_float)
_helper(current_mobile_module_float, torch.div)
else:
_helper(v3_module_int, historic_div_scalar_int_reciprocal)
_helper(current_module_int, torch.div)
_helper(v3_mobile_module_int, current_mobile_module_int)
_helper(current_mobile_module_int, torch.div)
@@ -470,9 +367,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
return a
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt")
v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl")
v3_mobile_module_int = _load_for_lite_interpreter(
@@ -480,22 +374,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
except Exception as e:
self.skipTest("Failed to load fixture!")
for m in (v3_module_float, v3_module_int):
self._verify_count("aten::div_", m, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
current_mobile_module_float = self._save_load_module(MyModuleFloat)
current_mobile_module_int = self._save_load_module(MyModuleInt)
for m in (current_module_float, current_module_int):
self._verify_count("aten::div", m, 1)
for m in (current_module_float, current_module_int):
self._verify_count("aten::div", m, 1)
for val_a, val_b in product(sample_input, sample_input):
a = torch.tensor((val_a,))
b = val_b
@@ -510,12 +391,8 @@ class TestSaveLoadForOpVersion(JitTestCase):
self.assertEqual(m_result, fn_result)
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float_inplace)
_helper(current_module_float, torch.Tensor.div_)
_helper(current_mobile_module_float, torch.Tensor.div_)
else:
_helper(v3_module_int, historic_div_scalar_int_inplace)
_helper(current_module_int, torch.Tensor.div_)
_helper(current_mobile_module_int, torch.Tensor.div_)
# NOTE: Scalar division was already true division in op version 3,
@@ -533,17 +410,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
return (result_0, result_1, result_2, result_3)
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt")
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 4)
current_module = self._save_load_module(MyModule)
current_mobile_module = self._save_load_mobile_module(MyModule)
self._verify_count("aten::div", current_module, 4)
def _helper(m, fn):
vals = (5., 3, 2., 7)
@@ -552,74 +424,4 @@ class TestSaveLoadForOpVersion(JitTestCase):
for mr, hr in zip(m_result, fn_result):
self.assertEqual(mr, hr)
_helper(v3_module, current_module)
_helper(v3_mobile_module, current_mobile_module)
# NOTE: the JIT was incapable of handling boolean fill values when
# PyTorch produced file format versions 0-4
def test_versioned_full_integer_value(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, int_fill: int):
size = torch.Size([2, 2])
a = torch.full(size, int_fill)
b = torch.full(size, 1)
return (a, b)
try:
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::full", v4_module, 2)
current_module = self._save_load_module(MyModule)
self._verify_count("aten::full", current_module, 2)
# Verifies historic integer type inference is float
# NOTE: only verifies floating point, not exact dtype, due to
# https://github.com/pytorch/pytorch/issues/40470
results = v4_module(2)
for result in results:
self.assertTrue(result.is_floating_point())
# Verifies values are correct
a, b = results
self.assertTrue((a == 2.).all())
self.assertTrue((b == 1.).all())
# Tests that torch.full behavior which is the same from prior versions
# to version 5 is preserved.
# NOTE: while torch.full in eager PyTorch accepts a requires_grad argument,
# it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363)
def test_versioned_full_preserved(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, float_fill: float):
size = (2, 2)
a = torch.full(size, 1.)
b = torch.full(size, float_fill)
c = torch.full(size, float_fill, dtype=torch.long)
out = torch.empty(size, dtype=torch.long)
d = torch.full(size, float_fill, out=out)
e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None,
layout=torch.strided, device='cpu')
return (a, b, c, d, e)
try:
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::full", v4_module, 5)
current_module = self._save_load_module(MyModule)
self._verify_count("aten::full", current_module, 5)
self.assertEqual(v4_module(2.), current_module(2.))

View File

@@ -4,6 +4,11 @@ import io
import os
import sys
import torch
import unittest
import zipfile
from torch.testing import FileCheck
from torch._C import _is_upgraders_enabled
from typing import Union
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -16,6 +21,15 @@ if __name__ == '__main__':
"instead.")
class TestUpgraders(JitTestCase):
def _load_model_version(self, loaded_model):
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
zipped_model = zipfile.ZipFile(buffer)
version = int(zipped_model.read('archive/version').decode("utf-8"))
return version
# TODO (tugsuu) We should ideally be generating these test cases.
def test_populated_upgrader_graph(self):
@torch.jit.script
def f():
@@ -67,7 +81,7 @@ class TestUpgraders(JitTestCase):
# upgrader map should have been populated by now
upgraders_size = torch._C._get_upgraders_map_size()
test_map = {"a": "b", "c": "d"}
test_map = {"a": str(torch._C.Graph()), "c": str(torch._C.Graph())}
torch._C._test_only_populate_upgraders(test_map)
upgraders_size_after_test = torch._C._get_upgraders_map_size()
self.assertEqual(upgraders_size_after_test - upgraders_size, 2)
@@ -81,3 +95,117 @@ class TestUpgraders(JitTestCase):
upgraders_dump_after_remove_test = torch._C._dump_upgraders_map()
self.assertTrue("a" not in upgraders_dump_after_remove_test)
self.assertTrue("c" not in upgraders_dump_after_remove_test)
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_div_tensor_at_3(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt"
loaded_model = torch.jit.load(model_path)
# There are 3 aten::div calls in this model, and the upgrader
# for aten::div uses two divs because of its if/else branch,
# so the upgraded graph contains 6 aten::div nodes.
FileCheck().check("prim::If").run(loaded_model.graph)
FileCheck().check_count("aten::div", 6).run(loaded_model.graph)
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 4)
loaded_model_twice = torch.jit.load(buffer)
# we check by its code because graph variable names
# can be different every time
self.assertEqual(loaded_model.code, loaded_model_twice.code)
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_test_serialization(self):
model_path = pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
# add test version entry to the version map
upgrader_bumped_version = 3
upgrader_name = "_test_serialization_subcmul_0_2"
upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema)
torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry)
# add test upgrader in the upgraders map
@torch.jit.script
def _test_serialization_subcmul_0_2(self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2) -> torch.Tensor:
return other - (self * alpha)
torch._C._test_only_populate_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
# test if the server is able to find the test upgraders and apply them to the IR
loaded_model = torch.jit.load(model_path)
FileCheck().check_count("aten::mul", 2).run(loaded_model.graph)
FileCheck().check_count("aten::sub", 2).run(loaded_model.graph)
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 3)
loaded_model_twice = torch.jit.load(buffer)
# we check by its code because graph variable names
# can be different every time
self.assertEqual(loaded_model.code, loaded_model_twice.code)
torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul")
torch._C._test_only_remove_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_div_scalar_at_3(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt"
loaded_model = torch.jit.load(model_path)
FileCheck().check("prim::If").run(loaded_model.graph)
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 4)
loaded_model_twice = torch.jit.load(buffer)
self.assertEqual(loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0))
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_div_tensor_out_at_3(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt"
loaded_model = torch.jit.load(model_path)
FileCheck().check("prim::If").run(loaded_model.graph)
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 4)
loaded_model_twice = torch.jit.load(buffer)
# we check by its code because graph variable names
# can be different every time
self.assertEqual(loaded_model.code, loaded_model_twice.code)
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_full_at_4(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt"
loaded_model = torch.jit.load(model_path)
FileCheck().check_count("aten::Float", 1).run(loaded_model.graph)
FileCheck().check_count("aten::full", 2).run(loaded_model.graph)
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 5)
loaded_model_twice = torch.jit.load(buffer)
# we check by its code because graph variable names
# can be different every time
self.assertEqual(loaded_model.code, loaded_model_twice.code)
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_full_out_at_4(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt"
loaded_model = torch.jit.load(model_path)
FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 5)

View File

@@ -68,6 +68,7 @@ from jit.test_attr import TestGetDefaultAttr # noqa: F401
from jit.test_aten_pow import TestAtenPow # noqa: F401
from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401
from jit.test_union import TestUnion # noqa: F401
from jit.test_legacy_upgraders import TestLegacyUpgraders # noqa: F401
from jit.test_models import MnistNet
from jit.test_batch_mm import TestBatchMM # noqa: F401
from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401

View File

@@ -112,6 +112,7 @@ core_sources_common = [
"torch/csrc/jit/mobile/type_parser.cpp",
"torch/csrc/jit/mobile/runtime_compatibility.cpp",
"torch/csrc/jit/operator_upgraders/version_map.cpp",
"torch/csrc/jit/operator_upgraders/upgraders_guard.cpp",
"torch/csrc/jit/runtime/instruction.cpp",
"torch/csrc/jit/runtime/jit_exception.cpp",
"torch/csrc/jit/runtime/operator.cpp",
@@ -121,7 +122,6 @@ core_sources_common = [
"torch/csrc/jit/runtime/vararg_functions.cpp",
"torch/csrc/jit/mobile/promoted_prim_ops.cpp",
"torch/csrc/jit/mobile/prim_ops_registery.cpp",
"torch/csrc/jit/operator_upgraders/upgraders.cpp",
"torch/csrc/profiler/util.cpp",
]
@@ -209,6 +209,8 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/mobile/nnc/context.cpp",
"torch/csrc/jit/mobile/nnc/registry.cpp",
"torch/csrc/jit/operator_upgraders/utils.cpp",
"torch/csrc/jit/operator_upgraders/upgraders_entry.cpp",
"torch/csrc/jit/operator_upgraders/upgraders.cpp",
"torch/csrc/jit/passes/annotate_warns.cpp",
"torch/csrc/jit/passes/bailout_graph.cpp",
"torch/csrc/jit/passes/batch_mm.cpp",

View File

@@ -293,7 +293,7 @@ def _create_function_from_trace(
def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
def parse_type_comment(comment: str) -> Decl: ...
def _populate_upgraders_map(content: Dict[str, str]) -> None: ...
def _is_upgraders_enabled() -> _bool: ...
def _get_upgraders_map_size() -> _int: ...
def _dump_upgraders_map() -> Dict[str, str]: ...
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...

View File

@@ -1,5 +1,6 @@
#include <torch/csrc/jit/frontend/builtin_functions.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/frontend/code_template.h>
#include <torch/csrc/jit/frontend/resolver.h>
@@ -85,6 +86,7 @@ def __contains__(self: str, key: str):
return self.find(key, 0, len(self)) != -1
)SCRIPT";
#if !ENABLE_UPGRADERS
// Implementations of historic symbol behaviors are defined here
// See note [Versioned Symbols]
@@ -158,7 +160,6 @@ def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
pin_memory:Optional[bool]=None) -> Tensor:
if dtype is None:
fill_value = float(fill_value)
return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
)SCRIPT";
@@ -168,6 +169,7 @@ auto full_out = R"SCRIPT(
def full_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
return torch.full(size, fill_value, out=out)
)SCRIPT";
#endif
struct BuiltinFunctionRegistry {
const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
@@ -237,6 +239,7 @@ struct BuiltinFunctionRegistry {
loadSource(aten_ops, "aten");
loadSource(aten_ops_additional, "aten");
#if !ENABLE_UPGRADERS
// Loads functions implementing historic behavior, see note [Versioned
// Symbols]
// Note: these functions go into the "upgraders" namespace
@@ -249,6 +252,7 @@ struct BuiltinFunctionRegistry {
loadSource(div__scalar, "upgraders");
loadSource(full, "upgraders");
loadSource(full_out, "upgraders");
#endif
// These are under `prim` instead of `aten` since they exist to bind certain
// tensor property getters to corresponding methods

View File

@@ -4,6 +4,7 @@
 #include <c10/util/Exception.h>
 #include <c10/util/StringUtil.h>
 #include <c10/util/irange.h>
+#include <caffe2/serialize/versions.h>
 #include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
 #include <torch/csrc/jit/frontend/convert_to_ssa.h>
@@ -22,6 +23,7 @@
 #include <torch/csrc/jit/passes/lift_closures.h>
 #include <torch/csrc/jit/passes/lower_tuples.h>
 #include <torch/csrc/jit/passes/normalize_ops.h>
+#include <torch/csrc/jit/passes/replacement_of_old_operators.h>
 #include <torch/csrc/jit/runtime/interpreter.h>
 #include <torch/csrc/jit/runtime/operator.h>
 #include <torch/csrc/jit/runtime/slice_indices_adjust.h>
@@ -656,6 +658,13 @@ struct to_ir {
     }

     method.setSchema(emitDef(def, self, graph->block()));
+#if ENABLE_UPGRADERS
+    // At this point, we might have received a graph that is compiled with
+    // old operator schemas that might not exist in the system anymore.
+    // Therefore, we replace such ops with their valid upgraders.
+    ReplaceOldOperatorsWithUpgraders(graph);
+#endif
     // NB ORDERING: SSA conversion has to occur before
     // lifting of closures and forks, this way closures are converted
     // to SSA while part of their original graph, and closures are ready to


@@ -4,6 +4,7 @@
 #include <c10/util/Exception.h>
 #include <c10/util/Optional.h>
 #include <c10/util/irange.h>
+#include <caffe2/serialize/versions.h>
 #include <torch/csrc/jit/frontend/builtin_functions.h>
 #include <torch/csrc/jit/frontend/error_report.h>
 #include <torch/csrc/jit/frontend/function_schema_parser.h>
@@ -469,7 +470,7 @@ static c10::optional<MatchedSchema> tryMatchSchema(
   }
   // construct the full name of the schema for easier look up
-  auto schema_name = schema.operator_name().name + "." + schema.overload_name();
+  auto schema_name = getFullSchemaName(schema);
   return MatchedSchema{
       std::move(positional_inputs),
@@ -619,6 +620,13 @@ static Value* emitBuiltinNode(
   return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
 }

+std::string getFullSchemaName(const ::c10::FunctionSchema& schema) {
+  if (schema.overload_name() != "") {
+    return schema.operator_name().name + "." + schema.overload_name();
+  }
+  return schema.operator_name().name;
+}
+
 // Search for operators matching the provided symbol name and input types.
 // If one is found, emit a node to the graph for that operator.
 Value* emitBuiltinCall(
@@ -631,8 +639,12 @@ Value* emitBuiltinCall(
   const auto& variants = getAllOperatorsFor(name);
   const auto& builtin_functions = getAllBuiltinFunctionsFor(name);

+#if ENABLE_UPGRADERS
   // first let's set the graph's version
   auto graph_version = graph.get_op_version();
+#else
+  c10::optional<size_t> graph_version = c10::nullopt;
+#endif

   std::stringstream failure_messages;
   std::vector<const FunctionSchema*> schemas;
@@ -643,9 +655,7 @@ Value* emitBuiltinCall(
   schemas.reserve(variants.size());
   for (const std::shared_ptr<Operator>& op : variants) {
     bool found_upgrader = false;
-    auto op_overload_name = op->schema().overload_name();
-    auto op_name = op->schema().operator_name().name +
-        ((op_overload_name != "") ? "." + op_overload_name : "");
+    auto op_name = getFullSchemaName(op->schema());
     if (graph_version.has_value()) {
       auto version_entry = get_operator_version_map().find(op_name);
       if (version_entry != get_operator_version_map().end()) {
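
Aside: the lookup key built by getFullSchemaName (and used against get_operator_version_map()) appends the overload name only when one exists. A minimal Python sketch of the same convention, with a hypothetical helper name:

import importlib  # stdlib only; sketch of the naming rule, not the real API

def full_schema_name(name: str, overload: str) -> str:
    # "aten::div" + "Tensor" -> "aten::div.Tensor"; no trailing dot otherwise
    return name + "." + overload if overload else name

assert full_schema_name("aten::div", "Tensor") == "aten::div.Tensor"
assert full_schema_name("aten::relu", "") == "aten::relu"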


@@ -41,6 +41,8 @@ TORCH_API bool convertibleToList(
     const TypePtr& type,
     const TypePtr& list_type_);

+TORCH_API std::string getFullSchemaName(const ::c10::FunctionSchema& schema);
+
 TORCH_API Value* emitBuiltinCall(
     const SourceRange& loc,
     Graph& graph,


@@ -5,6 +5,7 @@
 #include <utility>

 #include <ATen/core/symbol.h>
+#include <caffe2/serialize/versions.h>
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/frontend/error_report.h>
 #include <torch/csrc/jit/frontend/schema_matching.h>
@@ -319,12 +320,14 @@ struct TORCH_API BuiltinModule : public SugaredValue {
     }

     auto sym = Symbol::fromQualString(name + "::" + field);
+#if !ENABLE_UPGRADERS
     if (version.has_value()) {
       // Possibly replaces symbol with another that implements its
       // historic behavior.
       // See note [Versioned Symbols]
       sym = get_symbol_for_version(sym, *version);
     }
+#endif
     return std::make_shared<BuiltinFunction>(sym, c10::nullopt);
   }


@@ -1,12 +1,13 @@
 #include <torch/csrc/jit/frontend/versioned_symbols.h>

+#include <caffe2/serialize/versions.h>
 #include <torch/csrc/api/include/torch/jit.h>

 #include <unordered_map>

 namespace torch {
 namespace jit {
+#if !ENABLE_UPGRADERS
 // Note [Versioned Symbols]
 // When the schema or behavior of a symbol changes, serialized Torchscript
 // programs using that symbol are likely to break. To prevent those breaks,
@@ -105,6 +106,7 @@ uint64_t get_min_version_for_kind(const NodeKind& kind) {
   return it->second;
 }
+#endif

 } // namespace jit
 } // namespace torch


@@ -1,5 +1,6 @@
 #pragma once

+#include <caffe2/serialize/versions.h>
 #include <torch/csrc/Export.h>
 #include <torch/csrc/jit/api/module.h>
@@ -7,7 +8,7 @@
 namespace torch {
 namespace jit {
-
+#if !ENABLE_UPGRADERS
 // Maps the given symbol into an implementation of its behavior at the
 // given version.
 // See note [Versioned Symbols]
@@ -17,6 +18,6 @@ get_symbol_for_version(const Symbol name, const uint64_t version);
 // Maps the given kind to the minimum version that supports it.
 // See note [Dynamic Versions and torch.jit.save vs. torch.save]
 TORCH_API uint64_t get_min_version_for_kind(const NodeKind& kind);
-
+#endif
 } // namespace jit
 } // namespace torch


@@ -1,4 +1,7 @@
 #include <torch/csrc/jit/operator_upgraders/upgraders.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/irparser.h>
+
 #include <mutex>
 #include <string>
 #include <unordered_map>
@@ -9,12 +12,13 @@ namespace jit {
 static UpgradersMap upgradersMap;

 void UpgradersMap::set_content(
-    std::unordered_map<std::string, std::string>&& content) {
+    std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
   // make sure we populate the map only once
   std::lock_guard<std::mutex> _(lock);
   if (isPopulated) {
     return;
   }
   content_ = std::move(content);
   isPopulated = true;
 }
@@ -24,7 +28,12 @@ int UpgradersMap::count() {
   return content_.size();
 }

-const std::unordered_map<std::string, std::string>& UpgradersMap::
+bool UpgradersMap::is_populated() {
+  std::lock_guard<std::mutex> _(lock);
+  return isPopulated;
+}
+
+const std::unordered_map<std::string, std::shared_ptr<Graph>>& UpgradersMap::
     get_content() {
   std::lock_guard<std::mutex> _(lock);
   return content_;
@@ -34,7 +43,9 @@ void UpgradersMap::test_only_set_content(
     const std::unordered_map<std::string, std::string>& content) {
   std::lock_guard<std::mutex> _(lock);
   for (const auto& entry : content) {
-    content_.insert(entry);
+    auto graph = std::make_shared<Graph>();
+    torch::jit::parseIR(entry.second, graph.get());
+    content_.insert(std::make_pair(entry.first, graph));
   }
 }

 void UpgradersMap::test_only_remove_content(
@@ -46,7 +57,7 @@ void UpgradersMap::test_only_remove_content(
 }

 void populate_upgraders_map(
-    std::unordered_map<std::string, std::string>&& content) {
+    std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
   upgradersMap.set_content(std::move(content));
 }
@@ -54,7 +65,12 @@ int get_upgraders_map_size() {
   return upgradersMap.count();
 }

-const std::unordered_map<std::string, std::string>& dump_upgraders_map() {
+bool is_upgraders_map_populated() {
+  return upgradersMap.is_populated();
+}
+
+const std::unordered_map<std::string, std::shared_ptr<Graph>>&
+dump_upgraders_map() {
   return upgradersMap.get_content();
 }


@@ -1,5 +1,6 @@
 #pragma once
 #include <c10/macros/Export.h>
+#include <torch/csrc/jit/ir/ir.h>
 #include <iostream>
 #include <mutex>
 #include <string>
@@ -10,9 +11,11 @@ namespace jit {
 class UpgradersMap {
  public:
-  void set_content(std::unordered_map<std::string, std::string>&& content);
+  void set_content(
+      std::unordered_map<std::string, std::shared_ptr<Graph>>&& content);
   int count();
-  const std::unordered_map<std::string, std::string>& get_content();
+  const std::unordered_map<std::string, std::shared_ptr<Graph>>& get_content();
+  bool is_populated();
   // THESE METHODS ARE ONLY USED FOR TESTING PURPOSES
   void test_only_set_content(
       const std::unordered_map<std::string, std::string>& content);
@@ -20,17 +23,19 @@ class UpgradersMap {
       const std::unordered_map<std::string, std::string>& content);

  private:
-  std::unordered_map<std::string, std::string> content_;
+  std::unordered_map<std::string, std::shared_ptr<Graph>> content_;
   std::mutex lock;
   bool isPopulated = false;
 };

 TORCH_API void populate_upgraders_map(
-    std::unordered_map<std::string, std::string>&& content);
+    std::unordered_map<std::string, std::shared_ptr<Graph>>&& content);

 TORCH_API int get_upgraders_map_size();

-TORCH_API const std::unordered_map<std::string, std::string>&
+TORCH_API bool is_upgraders_map_populated();
+
+TORCH_API const std::unordered_map<std::string, std::shared_ptr<Graph>>&
 dump_upgraders_map();

 // THESE TWO METHODS BELOW ARE ONLY USED FOR TESTING
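
Aside: UpgradersMap is meant to be populated exactly once and then read concurrently. A rough Python sketch of that populate-once locking discipline (a simplified model, not the actual implementation):

import threading

class UpgradersMapSketch:
    def __init__(self):
        self._lock = threading.Lock()
        self._content = {}
        self._populated = False

    def set_content(self, content):
        # First writer wins; later calls are silently ignored.
        with self._lock:
            if self._populated:
                return
            self._content = content
            self._populated = True

    def is_populated(self):
        with self._lock:
            return self._populated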


@@ -0,0 +1,78 @@
+#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
+
+#include <ATen/core/stack.h>
+#include <c10/macros/Export.h>
+#include <torch/csrc/jit/api/compilation_unit.h>
+#include <torch/csrc/jit/api/function_impl.h>
+#include <torch/csrc/jit/frontend/ir_emitter.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/operator_upgraders/upgraders.h>
+#include <string>
+#include <unordered_map>
+
+namespace torch {
+namespace jit {
+
+static std::unordered_map<std::string, std::string> kUpgradersEntryMap(
+    {{"div_Tensor_0_3", R"SCRIPT(
+def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
+  if (self.is_floating_point() or other.is_floating_point()):
+    return self.true_divide(other)
+  return self.divide(other, rounding_mode='trunc')
+)SCRIPT"},
+     {"div_Scalar_0_3", R"SCRIPT(
+def div_Scalar_0_3(self: Tensor, other: number) -> Tensor:
+  if (self.is_floating_point() or isinstance(other, float)):
+    return self.true_divide(other)
+  return self.divide(other, rounding_mode='trunc')
+)SCRIPT"},
+     {"div_out_0_3", R"SCRIPT(
+def div_out_0_3(self: Tensor, other: Tensor, *, out: Tensor) -> Tensor:
+  if (self.is_floating_point() or other.is_floating_point() or out.is_floating_point()):
+    return self.true_divide(other, out=out)
+  return self.divide(other, rounding_mode='trunc', out=out)
+)SCRIPT"},
+     {"div__Tensor_0_3", R"SCRIPT(
+def div__Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
+  if (self.is_floating_point() or other.is_floating_point()):
+    return self.true_divide_(other)
+  return self.divide_(other, rounding_mode='trunc')
+)SCRIPT"},
+     {"div__Scalar_0_3", R"SCRIPT(
+def div__Scalar_0_3(self: Tensor, other: number) -> Tensor:
+  if (self.is_floating_point() or isinstance(other, float)):
+    return self.true_divide_(other)
+  return self.divide_(other, rounding_mode='trunc')
+)SCRIPT"},
+     {"full_0_4", R"SCRIPT(
+def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
+             layout:Optional[int]=None, device:Optional[Device]=None,
+             pin_memory:Optional[bool]=None) -> Tensor:
+  if dtype is None:
+    fill_value = float(fill_value)
+  return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
+)SCRIPT"},
+     {"full_out_0_4", R"SCRIPT(
+def full_out_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
+  return torch.full(size, fill_value, out=out)
+)SCRIPT"}});
+
+using UpgraderMap = std::unordered_map<std::string, std::shared_ptr<Graph>>;
+
+void populate_upgraders_graph_map() {
+  if (!is_upgraders_map_populated()) {
+    UpgraderMap populate_content;
+    for (const auto& entry : kUpgradersEntryMap) {
+      auto cu = std::make_shared<CompilationUnit>();
+      cu->define(c10::nullopt, entry.second, nativeResolver(), nullptr);
+      Function& jitFunc = cu->get_function(entry.first);
+      GraphFunction& graphFunction = toGraphFunction(jitFunc);
+      populate_content.insert(
+          std::make_pair(entry.first, graphFunction.graph()));
+    }
+    populate_upgraders_map(std::forward<UpgraderMap>(populate_content));
+  }
+}
+
+} // namespace jit
+} // namespace torch
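
Aside: the div upgraders above exist because integer division changed from truncation to true division at operator version 4. The plain-Python snippet below (assuming a current PyTorch install) shows the old behavior that div_Tensor_0_3 reproduces for old models:

import torch

a, b = torch.tensor([5]), torch.tensor([2])

print(a / b)  # new semantics: tensor([2.5000])

# old (pre-version-4) semantics, as encoded in div_Tensor_0_3:
if a.is_floating_point() or b.is_floating_point():
    out = a.true_divide(b)
else:
    out = a.divide(b, rounding_mode='trunc')
print(out)  # tensor([2])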


@@ -0,0 +1,14 @@
+#pragma once
+
+#include <c10/macros/Export.h>
+#include <iostream>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+
+namespace torch {
+namespace jit {
+
+TORCH_API void populate_upgraders_graph_map();
+
+} // namespace jit
+} // namespace torch


@@ -0,0 +1,12 @@
+#include <caffe2/serialize/versions.h>
+#include <torch/csrc/jit/operator_upgraders/upgraders_guard.h>
+
+namespace torch {
+namespace jit {
+
+bool is_upgraders_enabled() {
+  return ENABLE_UPGRADERS;
+}
+
+} // namespace jit
+} // namespace torch


@@ -0,0 +1,10 @@
+#pragma once
+
+#include <c10/macros/Export.h>
+
+namespace torch {
+namespace jit {
+
+TORCH_API bool is_upgraders_enabled();
+
+} // namespace jit
+} // namespace torch


@@ -31,11 +31,11 @@ static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersi
     {"aten::div_.Tensor",
      {{4,
        "div__Tensor_0_3",
-       "aten::div_.Tensor(Tensor(a!), Tensor other) -> Tensor(a!)"}}},
+       "aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"}}},
     {"aten::div_.Scalar",
      {{4,
        "div__Scalar_0_3",
-       "aten::div_.Scalar(Tensor(a!), Tensor other) -> Tensor(a!)"}}},
+       "aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"}}},
     {"aten::full",
      {{5,
        "full_0_4",


@@ -15,29 +15,6 @@
 namespace torch {
 namespace jit {

-static std::unordered_map<std::string, std::shared_ptr<Graph>> upgraderCache;
-
-std::shared_ptr<Graph> getUpgraderGraph(const std::string& upgrader_name) {
-  auto it = upgraderCache.find(upgrader_name);
-  if (it != upgraderCache.end()) {
-    return it->second;
-  }
-  auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name);
-  TORCH_INTERNAL_ASSERT(
-      upgrader_graph_entry != dump_upgraders_map().end(),
-      "Corresponding upgrader graph for ",
-      upgrader_name,
-      " must exist.",
-      " This upgrader"
-      " might be deprecated.");
-  auto upgrader_graph = std::make_shared<Graph>();
-  parseIR(upgrader_graph_entry->second, upgrader_graph.get());
-  upgraderCache[upgrader_name] = upgrader_graph;
-  return upgrader_graph;
-}
-
 struct OldOpsReplacerWithUpgraders {
   OldOpsReplacerWithUpgraders(std::shared_ptr<Graph> graph)
       : graph_(std::move(graph)) {}
@@ -52,9 +29,7 @@ struct OldOpsReplacerWithUpgraders {
     Node* node = graph_it.next();
     while (node) {
       if (auto schema = node->maybeSchema()) {
-        auto schema_name = schema->name() +
-            (schema->overload_name() != "" ? "." + schema->overload_name()
-                                           : "");
+        auto schema_name = getFullSchemaName(*schema);
         // this implies there was a version bump because of this operator
         auto version_entry = get_operator_version_map().find(schema_name);
         if (version_entry != get_operator_version_map().end()) {
@@ -74,8 +49,16 @@ struct OldOpsReplacerWithUpgraders {
           }
           auto upgrader_entry_val = upgrader_entry.value();
           auto upgrader_name = upgrader_entry_val.upgrader_name;
-          auto upgrader_graph = getUpgraderGraph(upgrader_name);
+          auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name);
+          TORCH_INTERNAL_ASSERT(
+              upgrader_graph_entry != dump_upgraders_map().end(),
+              "Corresponding upgrader graph for ",
+              upgrader_name,
+              " must exist.",
+              " This upgrader"
+              " might be deprecated.");
+          auto upgrader_graph = upgrader_graph_entry->second;
           // inline the upgrader function body
           WithInsertPoint guard(node);
           auto new_outputs = insertGraph(
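
Aside: the pass above walks every node and, for old-schema calls, inlines the matching upgrader graph at the call site. In rough Python pseudocode, with simplified stand-ins for the torch::jit IR types:

def replace_old_ops(graph, upgraders, version_map, graph_version):
    for node in list(graph.nodes()):
        entry = version_map.get(node.schema_name)
        if entry is None or graph_version >= entry.bumped_at_version:
            continue  # node already has current semantics
        body = upgraders[entry.upgrader_name]  # a parsed Graph, not source text
        new_outputs = graph.insert_graph(body, node.inputs, at=node)  # inline
        for old, new in zip(node.outputs, new_outputs):
            old.replace_all_uses_with(new)
        node.destroy()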


@@ -14,6 +14,7 @@
 #include <torch/csrc/jit/mobile/model_compatibility.h>
 #include <torch/csrc/jit/mobile/module.h>
 #include <torch/csrc/jit/operator_upgraders/upgraders.h>
+#include <torch/csrc/jit/operator_upgraders/upgraders_guard.h>
 #include <torch/csrc/jit/operator_upgraders/version_map.h>
 #include <torch/csrc/jit/python/module_python.h>
 #include <torch/csrc/jit/python/python_ivalue.h>
@@ -1731,7 +1732,8 @@ void initJitScriptBindings(PyObject* module) {
         return Decl(p.parseTypeComment());
       });

-  m.def("_populate_upgraders_map", &populate_upgraders_map);
+  m.def("_is_upgraders_enabled", &is_upgraders_enabled);
   m.def("_get_upgraders_map_size", &get_upgraders_map_size);
   m.def("_dump_upgraders_map", &dump_upgraders_map);


@@ -10,6 +10,7 @@
 #endif
 #include <torch/csrc/jit/frontend/script_type_parser.h>
 #include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
 #include <torch/csrc/jit/serialization/import_read.h>
 #include <torch/csrc/jit/serialization/import_source.h>
@@ -20,6 +21,7 @@
 #include <caffe2/serialize/file_adapter.h>
 #include <caffe2/serialize/inline_container.h>
 #include <caffe2/serialize/istream_adapter.h>
+#include <caffe2/serialize/versions.h>

 #include <ATen/ATen.h>
 #include <fmt/format.h>
@@ -239,6 +241,10 @@ graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_ze
 Module ScriptModuleDeserializer::deserialize(
     c10::optional<at::Device> device,
     ExtraFilesMap& extra_files) {
+  // we populate the upgraders map before any load starts
+#if ENABLE_UPGRADERS
+  populate_upgraders_graph_map();
+#endif
   C10_LOG_API_USAGE_ONCE("torch.script.load");
   device_ = device;
   // Load extra files.
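
Aside: the intent is that loading stays transparent to callers; a module saved by an older PyTorch keeps its original numerics after load (file name below is hypothetical):

import torch

m = torch.jit.load("old_model.pt")  # upgraders map is populated before parsing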


@@ -577,7 +577,9 @@ void SourceImporterImpl::importClass(
       /*propResolvers=*/{},
       methods,
       method_resolvers,
-      &self);
+      &self,
+      /*shouldMangle=*/false,
+      /*operator_set_version=*/version_);
   cu_->define_hooks(
       qualified_classname,
       hooks,


@@ -6,6 +6,7 @@
 #include <c10/util/Exception.h>
 #include <c10/util/StringUtil.h>
 #include <c10/util/irange.h>
+#include <caffe2/serialize/versions.h>
 #include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/frontend/error_report.h>
@@ -13,6 +14,7 @@
 #include <torch/csrc/jit/ir/attributes.h>
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/ir_views.h>
+#include <torch/csrc/jit/operator_upgraders/version_map.h>
 #include <torch/csrc/jit/resource_guard.h>
 #include <torch/csrc/jit/runtime/calculate_necessary_args.h>
@@ -741,9 +743,22 @@ struct PythonPrintImpl {
     }
   }

-  void checkVersion(const Node* const node) {
+  void checkVersion(Node* node) {
+#if ENABLE_UPGRADERS
+    if (auto schema = node->maybeSchema()) {
+      auto schema_name = getFullSchemaName(*schema);
+      auto version_entry = get_operator_version_map().find(schema_name);
+      if (version_entry != get_operator_version_map().end()) {
+        const auto& entry = version_entry->second;
+        // TODO (tugsuu) move this calculation into a separate step.
+        min_version_ = std::max(
+            min_version_, uint64_t(entry[entry.size() - 1].bumped_at_version));
+      }
+    }
+#else
     min_version_ =
         std::max(min_version_, get_min_version_for_kind(node->kind()));
+#endif
   }

   void printNode(Node* node, bool print_const) {
@@ -1594,7 +1609,11 @@ struct PythonPrintImpl {
   bool enforce_importable_;

   // The least version that supports all printed ops
+#if ENABLE_UPGRADERS
+  uint64_t min_version_ = caffe2::serialize::kMinSupportedFileFormatVersion;
+#else
   uint64_t min_version_ = 0;
+#endif
 };

 PythonPrint::PythonPrint(
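
Aside: the produced version is thus the maximum, over all printed ops, of each op's latest bumped_at_version, floored at the minimum supported file format version. A toy recomputation (illustrative data):

entries = {"aten::div.Tensor": [4], "aten::full": [5]}  # last bump per op

def min_file_version(ops, floor=1):  # floor ~ kMinSupportedFileFormatVersion
    version = floor
    for op in ops:
        if op in entries:
            version = max(version, entries[op][-1])
    return version

assert min_file_version(["aten::relu"]) == 1
assert min_file_version(["aten::div.Tensor", "aten::full"]) == 5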

View File

@@ -147,11 +147,6 @@ def load(f, map_location=None, _extra_files=None):
         os.remove("scriptmodule.pt")
     """
-    # TODO (tugsuu) Putting this on top level creates
-    # circular import issue. We need to fix this later
-    from torch.jit.operator_upgraders import populate_upgraders_map
-    populate_upgraders_map()
-
     if isinstance(f, string_classes):
         if not os.path.exists(f):  # type: ignore[type-var]
             raise ValueError("The provided filename {} does not exist".format(f))  # type: ignore[str-bytes-safe]


@@ -116,12 +116,5 @@ def generate_bytecode() -> List:
         yaml_content.append(entry)
     return yaml_content

-def populate_upgraders_map():
-    upgrader_set = collect_available_upgraders()
-    content = {}
-    for upgrader_name in upgrader_set:
-        content[upgrader_name] = str(globals()[upgrader_name].graph)
-    torch._C._populate_upgraders_map(content)
-
 if __name__ == "__main__":
     raise RuntimeError("This file is not meant to be run directly")