mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Bump version number to 7 and compile old operators with old schema (#68358)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68358 Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D33433730 Pulled By: tugsbayasgalan fbshipit-source-id: 202c58365bae13195d3545cefcb0da9162b02151
This commit is contained in:
parent
8bdbe94344
commit
b0fdca8855
|
|
@ -164,7 +164,13 @@ class TORCH_API PyTorchStreamWriter final {
|
||||||
std::string padding_;
|
std::string padding_;
|
||||||
std::ofstream file_stream_;
|
std::ofstream file_stream_;
|
||||||
std::function<size_t(const void*, size_t)> writer_func_;
|
std::function<size_t(const void*, size_t)> writer_func_;
|
||||||
|
// This number will be updated when the model has operators
|
||||||
|
// that have valid upgraders.
|
||||||
|
#if ENABLE_UPGRADERS
|
||||||
|
uint64_t version_ = kMinProducedFileFormatVersion;
|
||||||
|
#else
|
||||||
uint64_t version_ = kProducedFileFormatVersion;
|
uint64_t version_ = kProducedFileFormatVersion;
|
||||||
|
#endif
|
||||||
bool finalized_ = false;
|
bool finalized_ = false;
|
||||||
bool err_seen_ = false;
|
bool err_seen_ = false;
|
||||||
friend size_t ostream_write_func(
|
friend size_t ostream_write_func(
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,21 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
namespace caffe2 {
|
namespace caffe2 {
|
||||||
namespace serialize {
|
namespace serialize {
|
||||||
|
|
||||||
|
// Flag that controls if we want to enable upgraders
|
||||||
|
// in the server side. When this flag is set to False,
|
||||||
|
// it will switch to old dynamic versioning approach
|
||||||
|
#define ENABLE_UPGRADERS false
|
||||||
|
|
||||||
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
|
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
|
||||||
|
|
||||||
|
#if ENABLE_UPGRADERS
|
||||||
|
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x7L;
|
||||||
|
#else
|
||||||
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
|
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
|
||||||
|
#endif
|
||||||
|
|
||||||
// Versions (i.e. why was the version number bumped?)
|
// Versions (i.e. why was the version number bumped?)
|
||||||
|
|
||||||
|
|
@ -47,7 +56,23 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
|
||||||
// 5. (Dynamic) Stops torch.full inferring a floating point dtype
|
// 5. (Dynamic) Stops torch.full inferring a floating point dtype
|
||||||
// when given bool or integer fill values.
|
// when given bool or integer fill values.
|
||||||
// 6. Write version string to `./data/version` instead of `version`.
|
// 6. Write version string to `./data/version` instead of `version`.
|
||||||
|
|
||||||
|
#if ENABLE_UPGRADERS
|
||||||
|
// This is set to 7 from 3 due to a different interpretation of what
|
||||||
|
// file format version is. Whenever there is new upgrader introduced,
|
||||||
|
// this number should be bumped.
|
||||||
|
// 1. aten::div is changed at version 4
|
||||||
|
// 2. aten::full is changed at version 5
|
||||||
|
// 3. torch.package uses version 6
|
||||||
|
constexpr uint64_t kProducedFileFormatVersion = 0x7L;
|
||||||
|
#else
|
||||||
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
|
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Absolute minimum version we will write packages. This
|
||||||
|
// means that every package from now on will always be
|
||||||
|
// greater than this number.
|
||||||
|
constexpr uint64_t kMinProducedFileFormatVersion = 0x3L;
|
||||||
|
|
||||||
// The version we write when the archive contains bytecode.
|
// The version we write when the archive contains bytecode.
|
||||||
// It must be higher or eq to kProducedFileFormatVersion.
|
// It must be higher or eq to kProducedFileFormatVersion.
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,7 @@ set(JIT_TEST_SRCS
|
||||||
${JIT_TEST_ROOT}/test_alias_analysis.cpp
|
${JIT_TEST_ROOT}/test_alias_analysis.cpp
|
||||||
${JIT_TEST_ROOT}/test_argument_spec.cpp
|
${JIT_TEST_ROOT}/test_argument_spec.cpp
|
||||||
${JIT_TEST_ROOT}/test_autodiff.cpp
|
${JIT_TEST_ROOT}/test_autodiff.cpp
|
||||||
|
${JIT_TEST_ROOT}/test_load_upgraders.cpp
|
||||||
${JIT_TEST_ROOT}/test_op_replacement.cpp
|
${JIT_TEST_ROOT}/test_op_replacement.cpp
|
||||||
${JIT_TEST_ROOT}/test_upgrader_utils.cpp
|
${JIT_TEST_ROOT}/test_upgrader_utils.cpp
|
||||||
${JIT_TEST_ROOT}/test_backend.cpp
|
${JIT_TEST_ROOT}/test_backend.cpp
|
||||||
|
|
|
||||||
45
test/cpp/jit/test_load_upgraders.cpp
Normal file
45
test/cpp/jit/test_load_upgraders.cpp
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
#include <caffe2/serialize/versions.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <torch/csrc/jit/api/module.h>
|
||||||
|
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
|
||||||
|
#include <torch/csrc/jit/operator_upgraders/version_map.h>
|
||||||
|
#include <torch/csrc/jit/serialization/import.h>
|
||||||
|
|
||||||
|
#include <test/cpp/jit/test_utils.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
#if ENABLE_UPGRADERS
|
||||||
|
// Basic tests to check if C++ torch::jit::load
|
||||||
|
// can load the upgraders fine
|
||||||
|
// TODO (tugsuu) add more tests
|
||||||
|
TEST(UpgraderLoad, CanPopulateUpgradersGraph) {
|
||||||
|
Module m("m");
|
||||||
|
m.define(R"(
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
b = 5
|
||||||
|
return torch.div(x, b)
|
||||||
|
)");
|
||||||
|
std::stringstream ms;
|
||||||
|
m.save(ms);
|
||||||
|
auto loaded_m = torch::jit::load(ms);
|
||||||
|
auto version_map = get_operator_version_map();
|
||||||
|
auto upgraders = dump_upgraders_map();
|
||||||
|
|
||||||
|
for (const auto& entry : version_map) {
|
||||||
|
auto list_of_upgraders_for_op = entry.second;
|
||||||
|
for (const auto& upgrader_entry : list_of_upgraders_for_op) {
|
||||||
|
EXPECT_TRUE(
|
||||||
|
upgraders.find(upgrader_entry.upgrader_name) != upgraders.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto test_graph = loaded_m.get_method("forward").graph();
|
||||||
|
// should have saved with version 4, so it is still up to date
|
||||||
|
testing::FileCheck().check_count("aten::div", 1, true)->run(*test_graph);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
553
test/jit/test_legacy_upgraders.py
Normal file
553
test/jit/test_legacy_upgraders.py
Normal file
|
|
@ -0,0 +1,553 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
|
from itertools import product as product
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
# Make the helper files in test/ importable
|
||||||
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
sys.path.append(pytorch_test_dir)
|
||||||
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise RuntimeError(
|
||||||
|
"This test file is not meant to be run directly, use:\n\n"
|
||||||
|
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||||
|
"instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Legacy test cases for dynamic operator versioning
|
||||||
|
class TestLegacyUpgraders(JitTestCase):
|
||||||
|
reason: str = "Skipped due to new global operator versioning for operators"
|
||||||
|
|
||||||
|
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||||
|
def test_versioned_symbols(self):
|
||||||
|
"""
|
||||||
|
Tests Torchscript symbol versioning. See note [Versioned Symbols].
|
||||||
|
This test uses an undocumented, test-only function
|
||||||
|
torch._test_serialization_subcmul.
|
||||||
|
This function is implemented as (a - alpha * b) with a default value
|
||||||
|
of 1 for alpha. In file format version 2, however, it was implemented
|
||||||
|
as (b - alpha * a) with a default value of 2 for alpha.
|
||||||
|
This test verifies a module seralized with file format version 2
|
||||||
|
exhibits the old behavior, and that the same module newly serialized
|
||||||
|
exhibits the current behavior.
|
||||||
|
"""
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModule, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, a, b, alpha: float):
|
||||||
|
no_alpha = torch._test_serialization_subcmul(a, b)
|
||||||
|
with_alpha = torch._test_serialization_subcmul(a, b, alpha)
|
||||||
|
return no_alpha, with_alpha
|
||||||
|
|
||||||
|
def historic_subcmul(a, b, alpha=2):
|
||||||
|
return b - alpha * a
|
||||||
|
|
||||||
|
def current_subcmul(a, b, alpha=1):
|
||||||
|
return a - alpha * b
|
||||||
|
|
||||||
|
# Loads and verifies the historic behavior of the module
|
||||||
|
# that was serialized with version 2
|
||||||
|
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
|
||||||
|
a = torch.randn((5,))
|
||||||
|
b = torch.randn((5,))
|
||||||
|
alpha = random.random()
|
||||||
|
args = (a, b, alpha)
|
||||||
|
no_alpha_v2, with_alpha_v2 = module_v2(*args)
|
||||||
|
self.assertEqual(no_alpha_v2, historic_subcmul(a, b))
|
||||||
|
self.assertEqual(with_alpha_v2, historic_subcmul(*args))
|
||||||
|
|
||||||
|
# Scripts, saves, loads and verifies the current behavior of the module
|
||||||
|
scripted_module = torch.jit.script(MyModule())
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.jit.save(scripted_module, buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
module_current = torch.jit.load(buffer)
|
||||||
|
no_alpha_current, with_alpha_current = module_current(*args)
|
||||||
|
self.assertEqual(no_alpha_current, current_subcmul(a, b))
|
||||||
|
self.assertEqual(with_alpha_current, current_subcmul(*args))
|
||||||
|
|
||||||
|
# Helper that returns the module after saving and loading
|
||||||
|
def _save_load_module(self, m):
|
||||||
|
scripted_module = torch.jit.script(m())
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.jit.save(scripted_module, buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
return torch.jit.load(buffer)
|
||||||
|
|
||||||
|
# Helper which returns the result of a function or the exception the
|
||||||
|
# function threw.
|
||||||
|
def _try_fn(self, fn, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
return e
|
||||||
|
|
||||||
|
def _verify_no(self, kind, m):
|
||||||
|
self._verify_count(kind, m, 0)
|
||||||
|
|
||||||
|
def _verify_count(self, kind, m, count):
|
||||||
|
node_count = sum(str(n).count(kind) for n in m.graph.nodes())
|
||||||
|
self.assertEqual(node_count, count)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
|
||||||
|
to call either aten::true_divide(_), if an input is a float type,
|
||||||
|
or truncated aten::divide(_) otherwise.
|
||||||
|
NOTE: currently compares against current div behavior, too, since
|
||||||
|
div behavior has not yet been updated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||||
|
def test_versioned_div_tensor(self):
|
||||||
|
def historic_div(self, other):
|
||||||
|
if self.is_floating_point() or other.is_floating_point():
|
||||||
|
return self.true_divide(other)
|
||||||
|
return self.divide(other, rounding_mode='trunc')
|
||||||
|
|
||||||
|
# Tensor x Tensor
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModule, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, a, b):
|
||||||
|
result_0 = a / b
|
||||||
|
result_1 = torch.div(a, b)
|
||||||
|
result_2 = a.div(b)
|
||||||
|
|
||||||
|
return result_0, result_1, result_2
|
||||||
|
|
||||||
|
# Loads historic module
|
||||||
|
try:
|
||||||
|
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt")
|
||||||
|
except Exception as e:
|
||||||
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
|
self._verify_count("aten::div", v3_module, 6) # true_divide and divide alias to div
|
||||||
|
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
||||||
|
|
||||||
|
current_module = self._save_load_module(MyModule)
|
||||||
|
self._verify_count("aten::div", current_module, 3)
|
||||||
|
|
||||||
|
vals = (2., 3., 2, 3)
|
||||||
|
for val_a, val_b in product(vals, vals):
|
||||||
|
a = torch.tensor((val_a,))
|
||||||
|
b = torch.tensor((val_b,))
|
||||||
|
|
||||||
|
def _helper(m, fn):
|
||||||
|
m_results = self._try_fn(m, a, b)
|
||||||
|
fn_result = self._try_fn(fn, a, b)
|
||||||
|
|
||||||
|
if isinstance(m_results, Exception):
|
||||||
|
self.assertTrue(isinstance(fn_result, Exception))
|
||||||
|
else:
|
||||||
|
for result in m_results:
|
||||||
|
self.assertEqual(result, fn_result)
|
||||||
|
|
||||||
|
_helper(v3_module, historic_div)
|
||||||
|
_helper(current_module, torch.div)
|
||||||
|
|
||||||
|
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||||
|
def test_versioned_div_tensor_inplace(self):
|
||||||
|
def historic_div_(self, other):
|
||||||
|
if self.is_floating_point() or other.is_floating_point():
|
||||||
|
return self.true_divide_(other)
|
||||||
|
return self.divide_(other, rounding_mode='trunc')
|
||||||
|
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModule, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, a, b):
|
||||||
|
a /= b
|
||||||
|
return a
|
||||||
|
|
||||||
|
try:
|
||||||
|
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt")
|
||||||
|
except Exception as e:
|
||||||
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
|
self._verify_count("aten::div", v3_module, 2) # true_divide and divide both alias to div
|
||||||
|
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
||||||
|
|
||||||
|
current_module = self._save_load_module(MyModule)
|
||||||
|
self._verify_count("aten::div", current_module, 1)
|
||||||
|
|
||||||
|
vals = (2., 3., 2, 3)
|
||||||
|
for val_a, val_b in product(vals, vals):
|
||||||
|
a = torch.tensor((val_a,))
|
||||||
|
b = torch.tensor((val_b,))
|
||||||
|
|
||||||
|
def _helper(m, fn):
|
||||||
|
fn_result = self._try_fn(fn, a.clone(), b)
|
||||||
|
m_result = self._try_fn(m, a, b)
|
||||||
|
|
||||||
|
if isinstance(m_result, Exception):
|
||||||
|
self.assertTrue(fn_result, Exception)
|
||||||
|
else:
|
||||||
|
self.assertEqual(m_result, fn_result)
|
||||||
|
self.assertEqual(m_result, a)
|
||||||
|
|
||||||
|
_helper(v3_module, historic_div_)
|
||||||
|
|
||||||
|
# Recreates a since it was modified in place
|
||||||
|
a = torch.tensor((val_a,))
|
||||||
|
_helper(current_module, torch.Tensor.div_)
|
||||||
|
|
||||||
|
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||||
|
def test_versioned_div_tensor_out(self):
|
||||||
|
def historic_div_out(self, other, out):
|
||||||
|
if self.is_floating_point() or other.is_floating_point() or out.is_floating_point():
|
||||||
|
return torch.true_divide(self, other, out=out)
|
||||||
|
return torch.divide(self, other, out=out, rounding_mode='trunc')
|
||||||
|
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModule, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, a, b, out):
|
||||||
|
return a.div(b, out=out)
|
||||||
|
|
||||||
|
try:
|
||||||
|
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt")
|
||||||
|
except Exception as e:
|
||||||
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
|
self._verify_count("aten::div", v3_module, 2) # true_divide and divide alias to div
|
||||||
|
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
||||||
|
|
||||||
|
current_module = self._save_load_module(MyModule)
|
||||||
|
self._verify_count("aten::div", current_module, 1)
|
||||||
|
|
||||||
|
vals = (2., 3., 2, 3)
|
||||||
|
for val_a, val_b in product(vals, vals):
|
||||||
|
a = torch.tensor((val_a,))
|
||||||
|
b = torch.tensor((val_b,))
|
||||||
|
|
||||||
|
for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)):
|
||||||
|
def _helper(m, fn):
|
||||||
|
fn_result = None
|
||||||
|
if fn is torch.div:
|
||||||
|
fn_result = self._try_fn(fn, a, b, out=out.clone())
|
||||||
|
else:
|
||||||
|
fn_result = self._try_fn(fn, a, b, out.clone())
|
||||||
|
m_result = self._try_fn(m, a, b, out)
|
||||||
|
|
||||||
|
if isinstance(m_result, Exception):
|
||||||
|
self.assertTrue(fn_result, Exception)
|
||||||
|
else:
|
||||||
|
self.assertEqual(m_result, fn_result)
|
||||||
|
self.assertEqual(m_result, out)
|
||||||
|
|
||||||
|
_helper(v3_module, historic_div_out)
|
||||||
|
_helper(current_module, torch.div)
|
||||||
|
|
||||||
|
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||||
|
def test_versioned_div_scalar(self):
|
||||||
|
def historic_div_scalar_float(self, other: float):
|
||||||
|
return torch.true_divide(self, other)
|
||||||
|
|
||||||
|
def historic_div_scalar_int(self, other: int):
|
||||||
|
if self.is_floating_point():
|
||||||
|
return torch.true_divide(self, other)
|
||||||
|
return torch.divide(self, other, rounding_mode='trunc')
|
||||||
|
|
||||||
|
class MyModuleFloat(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModuleFloat, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, a, b: float):
|
||||||
|
return a / b
|
||||||
|
|
||||||
|
class MyModuleInt(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModuleInt, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, a, b: int):
|
||||||
|
return a / b
|
||||||
|
|
||||||
|
try:
|
||||||
|
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt")
|
||||||
|
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_int_v3.pt")
|
||||||
|
except Exception as e:
|
||||||
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
|
for m in (v3_module_float, v3_module_int):
|
||||||
|
self._verify_count("aten::div", m, 2) # true_divide and divide alias to div
|
||||||
|
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
|
||||||
|
|
||||||
|
current_module_float = self._save_load_module(MyModuleFloat)
|
||||||
|
current_module_int = self._save_load_module(MyModuleInt)
|
||||||
|
|
||||||
|
for m in (current_module_float, current_module_int):
|
||||||
|
self._verify_count("aten::div", m, 1)
|
||||||
|
|
||||||
|
vals = (2., 3., 2, 3)
|
||||||
|
for val_a, val_b in product(vals, vals):
|
||||||
|
a = torch.tensor((val_a,))
|
||||||
|
b = val_b
|
||||||
|
|
||||||
|
def _helper(m, fn):
|
||||||
|
m_result = self._try_fn(m, a, b)
|
||||||
|
fn_result = self._try_fn(fn, a, b)
|
||||||
|
|
||||||
|
if isinstance(m_result, Exception):
|
||||||
|
self.assertTrue(fn_result, Exception)
|
||||||
|
else:
|
||||||
|
self.assertEqual(m_result, fn_result)
|
||||||
|
|
||||||
|
if isinstance(b, float):
|
||||||
|
_helper(v3_module_float, historic_div_scalar_float)
|
||||||
|
_helper(current_module_float, torch.div)
|
||||||
|
else:
|
||||||
|
_helper(v3_module_int, historic_div_scalar_int)
|
||||||
|
_helper(current_module_int, torch.div)
|
||||||
|
|
||||||
|
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||||
|
def test_versioned_div_scalar_reciprocal(self):
|
||||||
|
def historic_div_scalar_float_reciprocal(self, other: float):
|
||||||
|
return other / self
|
||||||
|
|
||||||
|
def historic_div_scalar_int_reciprocal(self, other: int):
|
||||||
|
if self.is_floating_point():
|
||||||
|
return other / self
|
||||||
|
return torch.divide(other, self, rounding_mode='trunc')
|
||||||
|
|
||||||
|
class MyModuleFloat(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModuleFloat, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, a, b: float):
|
||||||
|
return b / a
|
||||||
|
|
||||||
|
class MyModuleInt(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModuleInt, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, a, b: int):
|
||||||
|
return b / a
|
||||||
|
|
||||||
|
try:
|
||||||
|
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt")
|
||||||
|
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt")
|
||||||
|
except Exception as e:
|
||||||
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
|
# NOTE: number / tensor is rewritten to torch.reciprocal(a) * b
|
||||||
|
# so true_divide and floor_divide do not appear in their graphs
|
||||||
|
for m in (v3_module_float, v3_module_int):
|
||||||
|
self._verify_no("aten::div", m)
|
||||||
|
self._verify_no("aten::true_divide", m)
|
||||||
|
self._verify_no("aten::floor_divide", m)
|
||||||
|
self._verify_count("aten::reciprocal", m, 1)
|
||||||
|
|
||||||
|
current_module_float = self._save_load_module(MyModuleFloat)
|
||||||
|
current_module_int = self._save_load_module(MyModuleInt)
|
||||||
|
|
||||||
|
vals = (2., 3., 2, 3)
|
||||||
|
for val_a, val_b in product(vals, vals):
|
||||||
|
a = torch.tensor((val_a,))
|
||||||
|
b = val_b
|
||||||
|
|
||||||
|
def _helper(m, fn):
|
||||||
|
m_result = self._try_fn(m, a, b)
|
||||||
|
fn_result = None
|
||||||
|
# Reverses argument order for torch.div
|
||||||
|
if fn is torch.div:
|
||||||
|
fn_result = self._try_fn(torch.div, b, a)
|
||||||
|
else:
|
||||||
|
fn_result = self._try_fn(fn, a, b)
|
||||||
|
|
||||||
|
if isinstance(m_result, Exception):
|
||||||
|
self.assertTrue(isinstance(fn_result, Exception))
|
||||||
|
elif fn is torch.div or a.is_floating_point():
|
||||||
|
self.assertEqual(m_result, fn_result)
|
||||||
|
else:
|
||||||
|
# Skip when fn is not torch.div and a is integral because
|
||||||
|
# historic_div_scalar_int performs floored division
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(b, float):
|
||||||
|
_helper(v3_module_float, historic_div_scalar_float_reciprocal)
|
||||||
|
_helper(current_module_float, torch.div)
|
||||||
|
else:
|
||||||
|
_helper(v3_module_int, historic_div_scalar_int_reciprocal)
|
||||||
|
_helper(current_module_int, torch.div)
|
||||||
|
|
||||||
|
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||||
|
def test_versioned_div_scalar_inplace(self):
|
||||||
|
def historic_div_scalar_float_inplace(self, other: float):
|
||||||
|
return self.true_divide_(other)
|
||||||
|
|
||||||
|
def historic_div_scalar_int_inplace(self, other: int):
|
||||||
|
if self.is_floating_point():
|
||||||
|
return self.true_divide_(other)
|
||||||
|
|
||||||
|
return self.divide_(other, rounding_mode='trunc')
|
||||||
|
|
||||||
|
class MyModuleFloat(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModuleFloat, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, a, b: float):
|
||||||
|
a /= b
|
||||||
|
return a
|
||||||
|
|
||||||
|
class MyModuleInt(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModuleInt, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, a, b: int):
|
||||||
|
a /= b
|
||||||
|
return a
|
||||||
|
|
||||||
|
try:
|
||||||
|
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt")
|
||||||
|
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt")
|
||||||
|
except Exception as e:
|
||||||
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
|
for m in (v3_module_float, v3_module_int):
|
||||||
|
self._verify_count("aten::div_", m, 2) # true_divide and divide alias to div
|
||||||
|
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
|
||||||
|
|
||||||
|
current_module_float = self._save_load_module(MyModuleFloat)
|
||||||
|
current_module_int = self._save_load_module(MyModuleInt)
|
||||||
|
|
||||||
|
for m in (current_module_float, current_module_int):
|
||||||
|
self._verify_count("aten::div", m, 1)
|
||||||
|
|
||||||
|
for m in (current_module_float, current_module_int):
|
||||||
|
self._verify_count("aten::div", m, 1)
|
||||||
|
|
||||||
|
vals = (2., 3., 2, 3)
|
||||||
|
for val_a, val_b in product(vals, vals):
|
||||||
|
a = torch.tensor((val_a,))
|
||||||
|
b = val_b
|
||||||
|
|
||||||
|
def _helper(m, fn):
|
||||||
|
m_result = self._try_fn(m, a, b)
|
||||||
|
fn_result = self._try_fn(fn, a, b)
|
||||||
|
|
||||||
|
if isinstance(m_result, Exception):
|
||||||
|
self.assertTrue(fn_result, Exception)
|
||||||
|
else:
|
||||||
|
self.assertEqual(m_result, fn_result)
|
||||||
|
|
||||||
|
if isinstance(b, float):
|
||||||
|
_helper(v3_module_float, historic_div_scalar_float_inplace)
|
||||||
|
_helper(current_module_float, torch.Tensor.div_)
|
||||||
|
else:
|
||||||
|
_helper(v3_module_int, historic_div_scalar_int_inplace)
|
||||||
|
_helper(current_module_int, torch.Tensor.div_)
|
||||||
|
|
||||||
|
# NOTE: Scalar division was already true division in op version 3,
|
||||||
|
# so this test verifies the behavior is unchanged.
|
||||||
|
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||||
|
def test_versioned_div_scalar_scalar(self):
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModule, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, a: float, b: int, c: float, d: int):
|
||||||
|
result_0 = a / b
|
||||||
|
result_1 = a / c
|
||||||
|
result_2 = b / c
|
||||||
|
result_3 = b / d
|
||||||
|
return (result_0, result_1, result_2, result_3)
|
||||||
|
|
||||||
|
try:
|
||||||
|
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt")
|
||||||
|
except Exception as e:
|
||||||
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
|
self._verify_count("aten::div", v3_module, 4)
|
||||||
|
|
||||||
|
current_module = self._save_load_module(MyModule)
|
||||||
|
self._verify_count("aten::div", current_module, 4)
|
||||||
|
|
||||||
|
def _helper(m, fn):
|
||||||
|
vals = (5., 3, 2., 7)
|
||||||
|
m_result = m(*vals)
|
||||||
|
fn_result = fn(*vals)
|
||||||
|
for mr, hr in zip(m_result, fn_result):
|
||||||
|
self.assertEqual(mr, hr)
|
||||||
|
|
||||||
|
_helper(v3_module, current_module)
|
||||||
|
|
||||||
|
# NOTE: the JIT was incapable of handling boolean fill values when
|
||||||
|
# PyTorch produced file format versions 0-4
|
||||||
|
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||||
|
def test_versioned_full_integer_value(self):
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModule, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, int_fill: int):
|
||||||
|
size = torch.Size(2, 2)
|
||||||
|
a = torch.full(size, int_fill)
|
||||||
|
b = torch.full(size, 1)
|
||||||
|
return (a, b)
|
||||||
|
|
||||||
|
try:
|
||||||
|
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt")
|
||||||
|
except Exception as e:
|
||||||
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
|
self._verify_count("aten::full", v4_module, 2)
|
||||||
|
|
||||||
|
current_module = self._save_load_module(MyModule)
|
||||||
|
self._verify_count("aten::full", current_module, 2)
|
||||||
|
|
||||||
|
# Verifies historic integer type inference is float
|
||||||
|
# NOTE: only verifies floating point, not exact dtype, due to
|
||||||
|
# https://github.com/pytorch/pytorch/issues/40470
|
||||||
|
results = v4_module(2)
|
||||||
|
for result in results:
|
||||||
|
self.assertTrue(result.is_floating_point())
|
||||||
|
|
||||||
|
# Verifies values are correct
|
||||||
|
a, b = results
|
||||||
|
self.assertTrue((a == 2.).all())
|
||||||
|
self.assertTrue((b == 1.).all())
|
||||||
|
|
||||||
|
# Tests that torch.full behavior which is the same from prior versions
|
||||||
|
# to version 5 is preserved.
|
||||||
|
# NOTE: while torch.full in eager PyTorch accepts a requires_grad argument,
|
||||||
|
# it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363)
|
||||||
|
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||||
|
def test_versioned_full_preserved(self):
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModule, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, float_fill: float):
|
||||||
|
size = (2, 2)
|
||||||
|
a = torch.full(size, 1.)
|
||||||
|
b = torch.full(size, float_fill)
|
||||||
|
c = torch.full(size, float_fill, dtype=torch.long)
|
||||||
|
|
||||||
|
out = torch.empty(size, dtype=torch.long)
|
||||||
|
d = torch.full(size, float_fill, out=out)
|
||||||
|
|
||||||
|
e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None,
|
||||||
|
layout=torch.strided, device='cpu')
|
||||||
|
return (a, b, c, d, e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt")
|
||||||
|
except Exception as e:
|
||||||
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
|
self._verify_count("aten::full", v4_module, 5)
|
||||||
|
|
||||||
|
current_module = self._save_load_module(MyModule)
|
||||||
|
self._verify_count("aten::full", current_module, 5)
|
||||||
|
|
||||||
|
self.assertEqual(v4_module(2.), current_module(2.))
|
||||||
|
|
@ -24,23 +24,6 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
|
|
||||||
class TestSaveLoad(JitTestCase):
|
class TestSaveLoad(JitTestCase):
|
||||||
|
|
||||||
def test_versioned_symbols_reserialization(self):
|
|
||||||
"""
|
|
||||||
Tests that loading and saving serialized Torchscript with a versioned
|
|
||||||
symbol won't persist the original function and will inline the
|
|
||||||
versioned builtin.
|
|
||||||
"""
|
|
||||||
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
|
|
||||||
buffer = io.BytesIO()
|
|
||||||
torch.jit.save(module_v2, buffer)
|
|
||||||
buffer.seek(0)
|
|
||||||
module_reserialized = torch.jit.load(buffer)
|
|
||||||
|
|
||||||
subcmul_nodes = sum("subcmul" in n.kind() for
|
|
||||||
n in module_reserialized.graph.nodes())
|
|
||||||
self.assertEqual(subcmul_nodes, 0)
|
|
||||||
|
|
||||||
def test_different_modules(self):
|
def test_different_modules(self):
|
||||||
"""
|
"""
|
||||||
Exercise the situation where we have the same qualified name
|
Exercise the situation where we have the same qualified name
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@
|
||||||
from itertools import product as product
|
from itertools import product as product
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import sys
|
import sys
|
||||||
import hypothesis.strategies as st
|
import hypothesis.strategies as st
|
||||||
from hypothesis import example, settings, given
|
from hypothesis import example, settings, given
|
||||||
|
|
@ -24,55 +23,6 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
|
|
||||||
class TestSaveLoadForOpVersion(JitTestCase):
|
class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
def test_versioned_symbols(self):
|
|
||||||
"""
|
|
||||||
Tests Torchscript symbol versioning. See note [Versioned Symbols].
|
|
||||||
This test uses an undocumented, test-only function
|
|
||||||
torch._test_serialization_subcmul.
|
|
||||||
|
|
||||||
This function is implemented as (a - alpha * b) with a default value
|
|
||||||
of 1 for alpha. In file format version 2, however, it was implemented
|
|
||||||
as (b - alpha * a) with a default value of 2 for alpha.
|
|
||||||
This test verifies a module seralized with file format version 2
|
|
||||||
exhibits the old behavior, and that the same module newly serialized
|
|
||||||
exhibits the current behavior.
|
|
||||||
"""
|
|
||||||
class MyModule(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(MyModule, self).__init__()
|
|
||||||
|
|
||||||
def forward(self, a, b, alpha: float):
|
|
||||||
no_alpha = torch._test_serialization_subcmul(a, b)
|
|
||||||
with_alpha = torch._test_serialization_subcmul(a, b, alpha)
|
|
||||||
return no_alpha, with_alpha
|
|
||||||
|
|
||||||
def historic_subcmul(a, b, alpha=2):
|
|
||||||
return b - alpha * a
|
|
||||||
|
|
||||||
def current_subcmul(a, b, alpha=1):
|
|
||||||
return a - alpha * b
|
|
||||||
|
|
||||||
# Loads and verifies the historic behavior of the module
|
|
||||||
# that was serialized with version 2
|
|
||||||
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
|
|
||||||
a = torch.randn((5,))
|
|
||||||
b = torch.randn((5,))
|
|
||||||
alpha = random.random()
|
|
||||||
args = (a, b, alpha)
|
|
||||||
no_alpha_v2, with_alpha_v2 = module_v2(*args)
|
|
||||||
self.assertEqual(no_alpha_v2, historic_subcmul(a, b))
|
|
||||||
self.assertEqual(with_alpha_v2, historic_subcmul(*args))
|
|
||||||
|
|
||||||
# Scripts, saves, loads and verifies the current behavior of the module
|
|
||||||
scripted_module = torch.jit.script(MyModule())
|
|
||||||
buffer = io.BytesIO()
|
|
||||||
torch.jit.save(scripted_module, buffer)
|
|
||||||
buffer.seek(0)
|
|
||||||
module_current = torch.jit.load(buffer)
|
|
||||||
no_alpha_current, with_alpha_current = module_current(*args)
|
|
||||||
self.assertEqual(no_alpha_current, current_subcmul(a, b))
|
|
||||||
self.assertEqual(with_alpha_current, current_subcmul(*args))
|
|
||||||
|
|
||||||
# Helper that returns the module after saving and loading
|
# Helper that returns the module after saving and loading
|
||||||
def _save_load_module(self, m):
|
def _save_load_module(self, m):
|
||||||
scripted_module = torch.jit.script(m())
|
scripted_module = torch.jit.script(m())
|
||||||
|
|
@ -107,7 +57,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
|
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
|
||||||
to call either aten::true_divide(_), if an input is a float type,
|
to call either aten::true_divide(_), if an input is a float type,
|
||||||
or truncated aten::divide(_) otherwise.
|
or truncated aten::divide(_) otherwise.
|
||||||
|
|
||||||
NOTE: currently compares against current div behavior, too, since
|
NOTE: currently compares against current div behavior, too, since
|
||||||
div behavior has not yet been updated.
|
div behavior has not yet been updated.
|
||||||
"""
|
"""
|
||||||
|
|
@ -137,18 +86,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
|
|
||||||
# Loads historic module
|
# Loads historic module
|
||||||
try:
|
try:
|
||||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt")
|
|
||||||
v3_mobile_module = _load_for_lite_interpreter(
|
v3_mobile_module = _load_for_lite_interpreter(
|
||||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl")
|
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.skipTest("Failed to load fixture!")
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
self._verify_count("aten::div", v3_module, 6) # true_divide and divide alias to div
|
|
||||||
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
|
||||||
|
|
||||||
current_module = self._save_load_module(MyModule)
|
|
||||||
current_mobile_module = self._save_load_mobile_module(MyModule)
|
current_mobile_module = self._save_load_mobile_module(MyModule)
|
||||||
self._verify_count("aten::div", current_module, 3)
|
|
||||||
|
|
||||||
for val_a, val_b in product(sample_input, sample_input):
|
for val_a, val_b in product(sample_input, sample_input):
|
||||||
a = torch.tensor((val_a,))
|
a = torch.tensor((val_a,))
|
||||||
|
|
@ -164,9 +107,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
for result in m_results:
|
for result in m_results:
|
||||||
self.assertEqual(result, fn_result)
|
self.assertEqual(result, fn_result)
|
||||||
|
|
||||||
_helper(v3_module, historic_div)
|
|
||||||
_helper(v3_mobile_module, historic_div)
|
_helper(v3_mobile_module, historic_div)
|
||||||
_helper(current_module, torch.div)
|
|
||||||
_helper(current_mobile_module, torch.div)
|
_helper(current_mobile_module, torch.div)
|
||||||
|
|
||||||
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
||||||
|
|
@ -189,18 +130,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
return a
|
return a
|
||||||
|
|
||||||
try:
|
try:
|
||||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt")
|
|
||||||
v3_mobile_module = _load_for_lite_interpreter(
|
v3_mobile_module = _load_for_lite_interpreter(
|
||||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl")
|
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.skipTest("Failed to load fixture!")
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
self._verify_count("aten::div", v3_module, 2) # true_divide and divide both alias to div
|
|
||||||
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
|
||||||
|
|
||||||
current_module = self._save_load_module(MyModule)
|
|
||||||
current_mobile_module = self._save_load_mobile_module(MyModule)
|
current_mobile_module = self._save_load_mobile_module(MyModule)
|
||||||
self._verify_count("aten::div", current_module, 1)
|
|
||||||
|
|
||||||
for val_a, val_b in product(sample_input, sample_input):
|
for val_a, val_b in product(sample_input, sample_input):
|
||||||
a = torch.tensor((val_a,))
|
a = torch.tensor((val_a,))
|
||||||
|
|
@ -215,12 +150,10 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
self.assertEqual(m_result, fn_result)
|
self.assertEqual(m_result, fn_result)
|
||||||
self.assertEqual(m_result, a)
|
self.assertEqual(m_result, a)
|
||||||
|
|
||||||
_helper(v3_module, historic_div_)
|
|
||||||
_helper(v3_mobile_module, historic_div_)
|
_helper(v3_mobile_module, historic_div_)
|
||||||
|
|
||||||
# Recreates a since it was modified in place
|
# Recreates a since it was modified in place
|
||||||
a = torch.tensor((val_a,))
|
a = torch.tensor((val_a,))
|
||||||
_helper(current_module, torch.Tensor.div_)
|
|
||||||
_helper(current_mobile_module, torch.Tensor.div_)
|
_helper(current_mobile_module, torch.Tensor.div_)
|
||||||
|
|
||||||
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
||||||
|
|
@ -242,18 +175,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
return a.div(b, out=out)
|
return a.div(b, out=out)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt")
|
|
||||||
v3_mobile_module = _load_for_lite_interpreter(
|
v3_mobile_module = _load_for_lite_interpreter(
|
||||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl")
|
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.skipTest("Failed to load fixture!")
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
self._verify_count("aten::div", v3_module, 2) # true_divide and divide alias to div
|
|
||||||
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
|
||||||
|
|
||||||
current_module = self._save_load_module(MyModule)
|
|
||||||
current_mobile_module = self._save_load_mobile_module(MyModule)
|
current_mobile_module = self._save_load_mobile_module(MyModule)
|
||||||
self._verify_count("aten::div", current_module, 1)
|
|
||||||
|
|
||||||
for val_a, val_b in product(sample_input, sample_input):
|
for val_a, val_b in product(sample_input, sample_input):
|
||||||
a = torch.tensor((val_a,))
|
a = torch.tensor((val_a,))
|
||||||
|
|
@ -274,8 +201,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
self.assertEqual(m_result, fn_result)
|
self.assertEqual(m_result, fn_result)
|
||||||
self.assertEqual(m_result, out)
|
self.assertEqual(m_result, out)
|
||||||
|
|
||||||
_helper(v3_module, historic_div_out)
|
|
||||||
_helper(current_module, torch.div)
|
|
||||||
_helper(v3_mobile_module, historic_div_out)
|
_helper(v3_mobile_module, historic_div_out)
|
||||||
_helper(current_mobile_module, torch.div)
|
_helper(current_mobile_module, torch.div)
|
||||||
|
|
||||||
|
|
@ -308,8 +233,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
return a / b
|
return a / b
|
||||||
|
|
||||||
try:
|
try:
|
||||||
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt")
|
|
||||||
v3_module_int = torch.jit.load(pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v3.pt")
|
|
||||||
v3_mobile_module_float = _load_for_lite_interpreter(
|
v3_mobile_module_float = _load_for_lite_interpreter(
|
||||||
pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl")
|
pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl")
|
||||||
v3_mobile_module_int = _load_for_lite_interpreter(
|
v3_mobile_module_int = _load_for_lite_interpreter(
|
||||||
|
|
@ -321,14 +244,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
self._verify_count("aten::div", m, 2) # true_divide and divide alias to div
|
self._verify_count("aten::div", m, 2) # true_divide and divide alias to div
|
||||||
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
|
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
|
||||||
|
|
||||||
current_module_float = self._save_load_module(MyModuleFloat)
|
|
||||||
current_module_int = self._save_load_module(MyModuleInt)
|
|
||||||
current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
|
current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
|
||||||
current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)
|
current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)
|
||||||
|
|
||||||
for m in (current_module_float, current_module_int):
|
|
||||||
self._verify_count("aten::div", m, 1)
|
|
||||||
|
|
||||||
for val_a, val_b in product(sample_input, sample_input):
|
for val_a, val_b in product(sample_input, sample_input):
|
||||||
a = torch.tensor((val_a,))
|
a = torch.tensor((val_a,))
|
||||||
b = val_b
|
b = val_b
|
||||||
|
|
@ -343,13 +261,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
self.assertEqual(m_result, fn_result)
|
self.assertEqual(m_result, fn_result)
|
||||||
|
|
||||||
if isinstance(b, float):
|
if isinstance(b, float):
|
||||||
_helper(v3_module_float, historic_div_scalar_float)
|
|
||||||
_helper(current_module_float, torch.div)
|
|
||||||
_helper(v3_mobile_module_float, current_mobile_module_float)
|
_helper(v3_mobile_module_float, current_mobile_module_float)
|
||||||
_helper(current_mobile_module_float, torch.div)
|
_helper(current_mobile_module_float, torch.div)
|
||||||
else:
|
else:
|
||||||
_helper(v3_module_int, historic_div_scalar_int)
|
|
||||||
_helper(current_module_int, torch.div)
|
|
||||||
_helper(v3_mobile_module_int, historic_div_scalar_int)
|
_helper(v3_mobile_module_int, historic_div_scalar_int)
|
||||||
_helper(current_mobile_module_int, torch.div)
|
_helper(current_mobile_module_int, torch.div)
|
||||||
|
|
||||||
|
|
@ -382,8 +296,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
return b / a
|
return b / a
|
||||||
|
|
||||||
try:
|
try:
|
||||||
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt")
|
|
||||||
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt")
|
|
||||||
v3_mobile_module_float = _load_for_lite_interpreter(
|
v3_mobile_module_float = _load_for_lite_interpreter(
|
||||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl")
|
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl")
|
||||||
v3_mobile_module_int = _load_for_lite_interpreter(
|
v3_mobile_module_int = _load_for_lite_interpreter(
|
||||||
|
|
@ -391,17 +303,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.skipTest("Failed to load fixture!")
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
# NOTE: number / tensor is rewritten to torch.reciprocal(a) * b
|
|
||||||
# so true_divide and floor_divide do not appear in their graphs
|
|
||||||
for m in (v3_module_float, v3_module_int):
|
|
||||||
self._verify_no("aten::div", m)
|
|
||||||
self._verify_no("aten::true_divide", m)
|
|
||||||
self._verify_no("aten::floor_divide", m)
|
|
||||||
self._verify_count("aten::reciprocal", m, 1)
|
|
||||||
|
|
||||||
current_module_float = self._save_load_module(MyModuleFloat)
|
|
||||||
current_module_int = self._save_load_module(MyModuleInt)
|
|
||||||
|
|
||||||
current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
|
current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
|
||||||
current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)
|
current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)
|
||||||
|
|
||||||
|
|
@ -428,13 +329,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if isinstance(b, float):
|
if isinstance(b, float):
|
||||||
_helper(v3_module_float, historic_div_scalar_float_reciprocal)
|
|
||||||
_helper(current_module_float, torch.div)
|
|
||||||
_helper(v3_mobile_module_float, current_mobile_module_float)
|
_helper(v3_mobile_module_float, current_mobile_module_float)
|
||||||
_helper(current_mobile_module_float, torch.div)
|
_helper(current_mobile_module_float, torch.div)
|
||||||
else:
|
else:
|
||||||
_helper(v3_module_int, historic_div_scalar_int_reciprocal)
|
|
||||||
_helper(current_module_int, torch.div)
|
|
||||||
_helper(v3_mobile_module_int, current_mobile_module_int)
|
_helper(v3_mobile_module_int, current_mobile_module_int)
|
||||||
_helper(current_mobile_module_int, torch.div)
|
_helper(current_mobile_module_int, torch.div)
|
||||||
|
|
||||||
|
|
@ -470,9 +367,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
return a
|
return a
|
||||||
|
|
||||||
try:
|
try:
|
||||||
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt")
|
|
||||||
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt")
|
|
||||||
|
|
||||||
v3_mobile_module_float = _load_for_lite_interpreter(
|
v3_mobile_module_float = _load_for_lite_interpreter(
|
||||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl")
|
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl")
|
||||||
v3_mobile_module_int = _load_for_lite_interpreter(
|
v3_mobile_module_int = _load_for_lite_interpreter(
|
||||||
|
|
@ -480,22 +374,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.skipTest("Failed to load fixture!")
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
for m in (v3_module_float, v3_module_int):
|
|
||||||
self._verify_count("aten::div_", m, 2) # true_divide and divide alias to div
|
|
||||||
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
|
|
||||||
|
|
||||||
current_module_float = self._save_load_module(MyModuleFloat)
|
|
||||||
current_module_int = self._save_load_module(MyModuleInt)
|
|
||||||
|
|
||||||
current_mobile_module_float = self._save_load_module(MyModuleFloat)
|
current_mobile_module_float = self._save_load_module(MyModuleFloat)
|
||||||
current_mobile_module_int = self._save_load_module(MyModuleInt)
|
current_mobile_module_int = self._save_load_module(MyModuleInt)
|
||||||
|
|
||||||
for m in (current_module_float, current_module_int):
|
|
||||||
self._verify_count("aten::div", m, 1)
|
|
||||||
|
|
||||||
for m in (current_module_float, current_module_int):
|
|
||||||
self._verify_count("aten::div", m, 1)
|
|
||||||
|
|
||||||
for val_a, val_b in product(sample_input, sample_input):
|
for val_a, val_b in product(sample_input, sample_input):
|
||||||
a = torch.tensor((val_a,))
|
a = torch.tensor((val_a,))
|
||||||
b = val_b
|
b = val_b
|
||||||
|
|
@ -510,12 +391,8 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
self.assertEqual(m_result, fn_result)
|
self.assertEqual(m_result, fn_result)
|
||||||
|
|
||||||
if isinstance(b, float):
|
if isinstance(b, float):
|
||||||
_helper(v3_module_float, historic_div_scalar_float_inplace)
|
|
||||||
_helper(current_module_float, torch.Tensor.div_)
|
|
||||||
_helper(current_mobile_module_float, torch.Tensor.div_)
|
_helper(current_mobile_module_float, torch.Tensor.div_)
|
||||||
else:
|
else:
|
||||||
_helper(v3_module_int, historic_div_scalar_int_inplace)
|
|
||||||
_helper(current_module_int, torch.Tensor.div_)
|
|
||||||
_helper(current_mobile_module_int, torch.Tensor.div_)
|
_helper(current_mobile_module_int, torch.Tensor.div_)
|
||||||
|
|
||||||
# NOTE: Scalar division was already true division in op version 3,
|
# NOTE: Scalar division was already true division in op version 3,
|
||||||
|
|
@ -533,17 +410,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
return (result_0, result_1, result_2, result_3)
|
return (result_0, result_1, result_2, result_3)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt")
|
|
||||||
v3_mobile_module = _load_for_lite_interpreter(
|
v3_mobile_module = _load_for_lite_interpreter(
|
||||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl")
|
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.skipTest("Failed to load fixture!")
|
self.skipTest("Failed to load fixture!")
|
||||||
|
|
||||||
self._verify_count("aten::div", v3_module, 4)
|
|
||||||
|
|
||||||
current_module = self._save_load_module(MyModule)
|
|
||||||
current_mobile_module = self._save_load_mobile_module(MyModule)
|
current_mobile_module = self._save_load_mobile_module(MyModule)
|
||||||
self._verify_count("aten::div", current_module, 4)
|
|
||||||
|
|
||||||
def _helper(m, fn):
|
def _helper(m, fn):
|
||||||
vals = (5., 3, 2., 7)
|
vals = (5., 3, 2., 7)
|
||||||
|
|
@ -552,74 +424,4 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||||
for mr, hr in zip(m_result, fn_result):
|
for mr, hr in zip(m_result, fn_result):
|
||||||
self.assertEqual(mr, hr)
|
self.assertEqual(mr, hr)
|
||||||
|
|
||||||
_helper(v3_module, current_module)
|
|
||||||
_helper(v3_mobile_module, current_mobile_module)
|
_helper(v3_mobile_module, current_mobile_module)
|
||||||
|
|
||||||
# NOTE: the JIT was incapable of handling boolean fill values when
|
|
||||||
# PyTorch produced file format versions 0-4
|
|
||||||
def test_versioned_full_integer_value(self):
|
|
||||||
class MyModule(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(MyModule, self).__init__()
|
|
||||||
|
|
||||||
def forward(self, int_fill: int):
|
|
||||||
size = torch.Size(2, 2)
|
|
||||||
a = torch.full(size, int_fill)
|
|
||||||
b = torch.full(size, 1)
|
|
||||||
return (a, b)
|
|
||||||
|
|
||||||
try:
|
|
||||||
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt")
|
|
||||||
except Exception as e:
|
|
||||||
self.skipTest("Failed to load fixture!")
|
|
||||||
|
|
||||||
self._verify_count("aten::full", v4_module, 2)
|
|
||||||
|
|
||||||
current_module = self._save_load_module(MyModule)
|
|
||||||
self._verify_count("aten::full", current_module, 2)
|
|
||||||
|
|
||||||
# Verifies historic integer type inference is float
|
|
||||||
# NOTE: only verifies floating point, not exact dtype, due to
|
|
||||||
# https://github.com/pytorch/pytorch/issues/40470
|
|
||||||
results = v4_module(2)
|
|
||||||
for result in results:
|
|
||||||
self.assertTrue(result.is_floating_point())
|
|
||||||
|
|
||||||
# Verifies values are correct
|
|
||||||
a, b = results
|
|
||||||
self.assertTrue((a == 2.).all())
|
|
||||||
self.assertTrue((b == 1.).all())
|
|
||||||
|
|
||||||
# Tests that torch.full behavior which is the same from prior versions
|
|
||||||
# to version 5 is preserved.
|
|
||||||
# NOTE: while torch.full in eager PyTorch accepts a requires_grad argument,
|
|
||||||
# it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363)
|
|
||||||
def test_versioned_full_preserved(self):
|
|
||||||
class MyModule(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(MyModule, self).__init__()
|
|
||||||
|
|
||||||
def forward(self, float_fill: float):
|
|
||||||
size = (2, 2)
|
|
||||||
a = torch.full(size, 1.)
|
|
||||||
b = torch.full(size, float_fill)
|
|
||||||
c = torch.full(size, float_fill, dtype=torch.long)
|
|
||||||
|
|
||||||
out = torch.empty(size, dtype=torch.long)
|
|
||||||
d = torch.full(size, float_fill, out=out)
|
|
||||||
|
|
||||||
e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None,
|
|
||||||
layout=torch.strided, device='cpu')
|
|
||||||
return (a, b, c, d, e)
|
|
||||||
|
|
||||||
try:
|
|
||||||
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt")
|
|
||||||
except Exception as e:
|
|
||||||
self.skipTest("Failed to load fixture!")
|
|
||||||
|
|
||||||
self._verify_count("aten::full", v4_module, 5)
|
|
||||||
|
|
||||||
current_module = self._save_load_module(MyModule)
|
|
||||||
self._verify_count("aten::full", current_module, 5)
|
|
||||||
|
|
||||||
self.assertEqual(v4_module(2.), current_module(2.))
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,11 @@ import io
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import torch
|
import torch
|
||||||
|
import unittest
|
||||||
|
import zipfile
|
||||||
|
from torch.testing import FileCheck
|
||||||
|
from torch._C import _is_upgraders_enabled
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
# Make the helper files in test/ importable
|
# Make the helper files in test/ importable
|
||||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
|
@ -16,6 +21,15 @@ if __name__ == '__main__':
|
||||||
"instead.")
|
"instead.")
|
||||||
|
|
||||||
class TestUpgraders(JitTestCase):
|
class TestUpgraders(JitTestCase):
|
||||||
|
def _load_model_version(self, loaded_model):
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.jit.save(loaded_model, buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
zipped_model = zipfile.ZipFile(buffer)
|
||||||
|
version = int(zipped_model.read('archive/version').decode("utf-8"))
|
||||||
|
return version
|
||||||
|
|
||||||
|
# TODO (tugsuu) We should ideally be generating this test cases.
|
||||||
def test_populated_upgrader_graph(self):
|
def test_populated_upgrader_graph(self):
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def f():
|
def f():
|
||||||
|
|
@ -67,7 +81,7 @@ class TestUpgraders(JitTestCase):
|
||||||
# upgrader map should have populated now
|
# upgrader map should have populated now
|
||||||
upgraders_size = torch._C._get_upgraders_map_size()
|
upgraders_size = torch._C._get_upgraders_map_size()
|
||||||
|
|
||||||
test_map = {"a": "b", "c": "d"}
|
test_map = {"a": str(torch._C.Graph()), "c": str(torch._C.Graph())}
|
||||||
torch._C._test_only_populate_upgraders(test_map)
|
torch._C._test_only_populate_upgraders(test_map)
|
||||||
upgraders_size_after_test = torch._C._get_upgraders_map_size()
|
upgraders_size_after_test = torch._C._get_upgraders_map_size()
|
||||||
self.assertEqual(upgraders_size_after_test - upgraders_size, 2)
|
self.assertEqual(upgraders_size_after_test - upgraders_size, 2)
|
||||||
|
|
@ -81,3 +95,117 @@ class TestUpgraders(JitTestCase):
|
||||||
upgraders_dump_after_remove_test = torch._C._dump_upgraders_map()
|
upgraders_dump_after_remove_test = torch._C._dump_upgraders_map()
|
||||||
self.assertTrue("a" not in upgraders_dump_after_remove_test)
|
self.assertTrue("a" not in upgraders_dump_after_remove_test)
|
||||||
self.assertTrue("c" not in upgraders_dump_after_remove_test)
|
self.assertTrue("c" not in upgraders_dump_after_remove_test)
|
||||||
|
|
||||||
|
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||||
|
def test_aten_div_tensor_at_3(self):
|
||||||
|
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt"
|
||||||
|
loaded_model = torch.jit.load(model_path)
|
||||||
|
# there are 3 aten::div in this model
|
||||||
|
# And the upgrader for aten::div uses two
|
||||||
|
# div's because of if/else branch
|
||||||
|
FileCheck().check("prim::If").run(loaded_model.graph)
|
||||||
|
FileCheck().check_count("aten::div", 6).run(loaded_model.graph)
|
||||||
|
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.jit.save(loaded_model, buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
version = self._load_model_version(loaded_model)
|
||||||
|
self.assertTrue(version == 4)
|
||||||
|
loaded_model_twice = torch.jit.load(buffer)
|
||||||
|
# we check by its code because graph variable names
|
||||||
|
# can be different every time
|
||||||
|
self.assertEqual(loaded_model.code, loaded_model_twice.code)
|
||||||
|
|
||||||
|
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||||
|
def test_aten_test_serialization(self):
|
||||||
|
model_path = pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
|
||||||
|
|
||||||
|
# add test version entry to the version map
|
||||||
|
upgrader_bumped_version = 3
|
||||||
|
upgrader_name = "_test_serialization_subcmul_0_2"
|
||||||
|
upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
|
||||||
|
dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema)
|
||||||
|
|
||||||
|
torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry)
|
||||||
|
|
||||||
|
# add test upgrader in the upgraders map
|
||||||
|
@torch.jit.script
|
||||||
|
def _test_serialization_subcmul_0_2(self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2) -> torch.Tensor:
|
||||||
|
return other - (self * alpha)
|
||||||
|
torch._C._test_only_populate_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
|
||||||
|
|
||||||
|
# test if the server is able to find the test upgraders and apply to IR
|
||||||
|
loaded_model = torch.jit.load(model_path)
|
||||||
|
FileCheck().check_count("aten::mul", 2).run(loaded_model.graph)
|
||||||
|
FileCheck().check_count("aten::sub", 2).run(loaded_model.graph)
|
||||||
|
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.jit.save(loaded_model, buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
version = self._load_model_version(loaded_model)
|
||||||
|
self.assertTrue(version == 3)
|
||||||
|
loaded_model_twice = torch.jit.load(buffer)
|
||||||
|
# we check by its' code because graph variable names
|
||||||
|
# can be different every time
|
||||||
|
self.assertEqual(loaded_model.code, loaded_model_twice.code)
|
||||||
|
torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul")
|
||||||
|
torch._C._test_only_remove_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
|
||||||
|
|
||||||
|
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||||
|
def test_aten_div_scalar_at_3(self):
|
||||||
|
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt"
|
||||||
|
loaded_model = torch.jit.load(model_path)
|
||||||
|
FileCheck().check("prim::If").run(loaded_model.graph)
|
||||||
|
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
|
||||||
|
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.jit.save(loaded_model, buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
version = self._load_model_version(loaded_model)
|
||||||
|
self.assertTrue(version == 4)
|
||||||
|
loaded_model_twice = torch.jit.load(buffer)
|
||||||
|
|
||||||
|
self.assertEqual(loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
|
||||||
|
loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0))
|
||||||
|
|
||||||
|
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||||
|
def test_aten_div_tensor_out_at_3(self):
|
||||||
|
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt"
|
||||||
|
loaded_model = torch.jit.load(model_path)
|
||||||
|
FileCheck().check("prim::If").run(loaded_model.graph)
|
||||||
|
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
|
||||||
|
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.jit.save(loaded_model, buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
version = self._load_model_version(loaded_model)
|
||||||
|
self.assertTrue(version == 4)
|
||||||
|
loaded_model_twice = torch.jit.load(buffer)
|
||||||
|
# we check by its' code because graph variable names
|
||||||
|
# can be different every time
|
||||||
|
self.assertEqual(loaded_model.code, loaded_model_twice.code)
|
||||||
|
|
||||||
|
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||||
|
def test_aten_full_at_4(self):
|
||||||
|
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt"
|
||||||
|
loaded_model = torch.jit.load(model_path)
|
||||||
|
FileCheck().check_count("aten::Float", 1).run(loaded_model.graph)
|
||||||
|
FileCheck().check_count("aten::full", 2).run(loaded_model.graph)
|
||||||
|
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.jit.save(loaded_model, buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
version = self._load_model_version(loaded_model)
|
||||||
|
self.assertTrue(version == 5)
|
||||||
|
loaded_model_twice = torch.jit.load(buffer)
|
||||||
|
# we check by its' code because graph variable names
|
||||||
|
# can be different every time
|
||||||
|
self.assertEqual(loaded_model.code, loaded_model_twice.code)
|
||||||
|
|
||||||
|
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||||
|
def test_aten_full_out_at_4(self):
|
||||||
|
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt"
|
||||||
|
loaded_model = torch.jit.load(model_path)
|
||||||
|
FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
|
||||||
|
version = self._load_model_version(loaded_model)
|
||||||
|
self.assertTrue(version == 5)
|
||||||
|
|
|
||||||
|
|
@ -68,6 +68,7 @@ from jit.test_attr import TestGetDefaultAttr # noqa: F401
|
||||||
from jit.test_aten_pow import TestAtenPow # noqa: F401
|
from jit.test_aten_pow import TestAtenPow # noqa: F401
|
||||||
from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401
|
from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401
|
||||||
from jit.test_union import TestUnion # noqa: F401
|
from jit.test_union import TestUnion # noqa: F401
|
||||||
|
from jit.test_legacy_upgraders import TestLegacyUpgraders # noqa: F401
|
||||||
from jit.test_models import MnistNet
|
from jit.test_models import MnistNet
|
||||||
from jit.test_batch_mm import TestBatchMM # noqa: F401
|
from jit.test_batch_mm import TestBatchMM # noqa: F401
|
||||||
from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401
|
from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -112,6 +112,7 @@ core_sources_common = [
|
||||||
"torch/csrc/jit/mobile/type_parser.cpp",
|
"torch/csrc/jit/mobile/type_parser.cpp",
|
||||||
"torch/csrc/jit/mobile/runtime_compatibility.cpp",
|
"torch/csrc/jit/mobile/runtime_compatibility.cpp",
|
||||||
"torch/csrc/jit/operator_upgraders/version_map.cpp",
|
"torch/csrc/jit/operator_upgraders/version_map.cpp",
|
||||||
|
"torch/csrc/jit/operator_upgraders/upgraders_guard.cpp",
|
||||||
"torch/csrc/jit/runtime/instruction.cpp",
|
"torch/csrc/jit/runtime/instruction.cpp",
|
||||||
"torch/csrc/jit/runtime/jit_exception.cpp",
|
"torch/csrc/jit/runtime/jit_exception.cpp",
|
||||||
"torch/csrc/jit/runtime/operator.cpp",
|
"torch/csrc/jit/runtime/operator.cpp",
|
||||||
|
|
@ -121,7 +122,6 @@ core_sources_common = [
|
||||||
"torch/csrc/jit/runtime/vararg_functions.cpp",
|
"torch/csrc/jit/runtime/vararg_functions.cpp",
|
||||||
"torch/csrc/jit/mobile/promoted_prim_ops.cpp",
|
"torch/csrc/jit/mobile/promoted_prim_ops.cpp",
|
||||||
"torch/csrc/jit/mobile/prim_ops_registery.cpp",
|
"torch/csrc/jit/mobile/prim_ops_registery.cpp",
|
||||||
"torch/csrc/jit/operator_upgraders/upgraders.cpp",
|
|
||||||
"torch/csrc/profiler/util.cpp",
|
"torch/csrc/profiler/util.cpp",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -209,6 +209,8 @@ core_sources_full_mobile_no_backend_interface = [
|
||||||
"torch/csrc/jit/mobile/nnc/context.cpp",
|
"torch/csrc/jit/mobile/nnc/context.cpp",
|
||||||
"torch/csrc/jit/mobile/nnc/registry.cpp",
|
"torch/csrc/jit/mobile/nnc/registry.cpp",
|
||||||
"torch/csrc/jit/operator_upgraders/utils.cpp",
|
"torch/csrc/jit/operator_upgraders/utils.cpp",
|
||||||
|
"torch/csrc/jit/operator_upgraders/upgraders_entry.cpp",
|
||||||
|
"torch/csrc/jit/operator_upgraders/upgraders.cpp",
|
||||||
"torch/csrc/jit/passes/annotate_warns.cpp",
|
"torch/csrc/jit/passes/annotate_warns.cpp",
|
||||||
"torch/csrc/jit/passes/bailout_graph.cpp",
|
"torch/csrc/jit/passes/bailout_graph.cpp",
|
||||||
"torch/csrc/jit/passes/batch_mm.cpp",
|
"torch/csrc/jit/passes/batch_mm.cpp",
|
||||||
|
|
|
||||||
|
|
@ -293,7 +293,7 @@ def _create_function_from_trace(
|
||||||
def _jit_is_script_object(obj: Any) -> _bool: ...
|
def _jit_is_script_object(obj: Any) -> _bool: ...
|
||||||
def _last_executed_optimized_graph() -> Graph: ...
|
def _last_executed_optimized_graph() -> Graph: ...
|
||||||
def parse_type_comment(comment: str) -> Decl: ...
|
def parse_type_comment(comment: str) -> Decl: ...
|
||||||
def _populate_upgraders_map(content: Dict[str, str]) -> None: ...
|
def _is_upgraders_enabled() -> _bool: ...
|
||||||
def _get_upgraders_map_size() -> _int: ...
|
def _get_upgraders_map_size() -> _int: ...
|
||||||
def _dump_upgraders_map() -> Dict[str, str]: ...
|
def _dump_upgraders_map() -> Dict[str, str]: ...
|
||||||
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
|
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
#include <torch/csrc/jit/frontend/builtin_functions.h>
|
#include <torch/csrc/jit/frontend/builtin_functions.h>
|
||||||
|
|
||||||
|
#include <caffe2/serialize/versions.h>
|
||||||
#include <torch/csrc/api/include/torch/jit.h>
|
#include <torch/csrc/api/include/torch/jit.h>
|
||||||
#include <torch/csrc/jit/frontend/code_template.h>
|
#include <torch/csrc/jit/frontend/code_template.h>
|
||||||
#include <torch/csrc/jit/frontend/resolver.h>
|
#include <torch/csrc/jit/frontend/resolver.h>
|
||||||
|
|
@ -85,6 +86,7 @@ def __contains__(self: str, key: str):
|
||||||
return self.find(key, 0, len(self)) != -1
|
return self.find(key, 0, len(self)) != -1
|
||||||
)SCRIPT";
|
)SCRIPT";
|
||||||
|
|
||||||
|
#if !ENABLE_UPGRADERS
|
||||||
// Implementations of historic symbol behaviors are defined here
|
// Implementations of historic symbol behaviors are defined here
|
||||||
// See note [Versioned Symbols]
|
// See note [Versioned Symbols]
|
||||||
|
|
||||||
|
|
@ -158,7 +160,6 @@ def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
|
||||||
pin_memory:Optional[bool]=None) -> Tensor:
|
pin_memory:Optional[bool]=None) -> Tensor:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
fill_value = float(fill_value)
|
fill_value = float(fill_value)
|
||||||
|
|
||||||
return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
|
return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
|
||||||
)SCRIPT";
|
)SCRIPT";
|
||||||
|
|
||||||
|
|
@ -168,6 +169,7 @@ auto full_out = R"SCRIPT(
|
||||||
def full_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
|
def full_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
|
||||||
return torch.full(size, fill_value, out=out)
|
return torch.full(size, fill_value, out=out)
|
||||||
)SCRIPT";
|
)SCRIPT";
|
||||||
|
#endif
|
||||||
|
|
||||||
struct BuiltinFunctionRegistry {
|
struct BuiltinFunctionRegistry {
|
||||||
const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
|
const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
|
||||||
|
|
@ -237,6 +239,7 @@ struct BuiltinFunctionRegistry {
|
||||||
loadSource(aten_ops, "aten");
|
loadSource(aten_ops, "aten");
|
||||||
loadSource(aten_ops_additional, "aten");
|
loadSource(aten_ops_additional, "aten");
|
||||||
|
|
||||||
|
#if !ENABLE_UPGRADERS
|
||||||
// Loads functions implementing historic behavior, see note [Versioned
|
// Loads functions implementing historic behavior, see note [Versioned
|
||||||
// Symbols]
|
// Symbols]
|
||||||
// Note: these functions go into the "upgraders" namespace
|
// Note: these functions go into the "upgraders" namespace
|
||||||
|
|
@ -249,6 +252,7 @@ struct BuiltinFunctionRegistry {
|
||||||
loadSource(div__scalar, "upgraders");
|
loadSource(div__scalar, "upgraders");
|
||||||
loadSource(full, "upgraders");
|
loadSource(full, "upgraders");
|
||||||
loadSource(full_out, "upgraders");
|
loadSource(full_out, "upgraders");
|
||||||
|
#endif
|
||||||
|
|
||||||
// These are under `prim` instead of `aten` since they exist to bind certain
|
// These are under `prim` instead of `aten` since they exist to bind certain
|
||||||
// tensor property getters to correpsonding methods
|
// tensor property getters to correpsonding methods
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <c10/util/StringUtil.h>
|
#include <c10/util/StringUtil.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
|
#include <caffe2/serialize/versions.h>
|
||||||
#include <torch/csrc/jit/api/function_impl.h>
|
#include <torch/csrc/jit/api/function_impl.h>
|
||||||
#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
|
#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
|
||||||
#include <torch/csrc/jit/frontend/convert_to_ssa.h>
|
#include <torch/csrc/jit/frontend/convert_to_ssa.h>
|
||||||
|
|
@ -22,6 +23,7 @@
|
||||||
#include <torch/csrc/jit/passes/lift_closures.h>
|
#include <torch/csrc/jit/passes/lift_closures.h>
|
||||||
#include <torch/csrc/jit/passes/lower_tuples.h>
|
#include <torch/csrc/jit/passes/lower_tuples.h>
|
||||||
#include <torch/csrc/jit/passes/normalize_ops.h>
|
#include <torch/csrc/jit/passes/normalize_ops.h>
|
||||||
|
#include <torch/csrc/jit/passes/replacement_of_old_operators.h>
|
||||||
#include <torch/csrc/jit/runtime/interpreter.h>
|
#include <torch/csrc/jit/runtime/interpreter.h>
|
||||||
#include <torch/csrc/jit/runtime/operator.h>
|
#include <torch/csrc/jit/runtime/operator.h>
|
||||||
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
|
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
|
||||||
|
|
@ -656,6 +658,13 @@ struct to_ir {
|
||||||
}
|
}
|
||||||
method.setSchema(emitDef(def, self, graph->block()));
|
method.setSchema(emitDef(def, self, graph->block()));
|
||||||
|
|
||||||
|
#if ENABLE_UPGRADERS
|
||||||
|
// At this point, we might have received a graph that is compiled with
|
||||||
|
// old operator schemas that might not exist in the system anymore.
|
||||||
|
// Therefore, we replace such ops with its' valid upgrader.
|
||||||
|
ReplaceOldOperatorsWithUpgraders(graph);
|
||||||
|
#endif
|
||||||
|
|
||||||
// NB ORDERING: SSA conversion has to occur before
|
// NB ORDERING: SSA conversion has to occur before
|
||||||
// lifting of closures and forks, this way closures are converted
|
// lifting of closures and forks, this way closures are converted
|
||||||
// to SSA while part of their original graph, and closures are ready to
|
// to SSA while part of their original graph, and closures are ready to
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <c10/util/Optional.h>
|
#include <c10/util/Optional.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
|
#include <caffe2/serialize/versions.h>
|
||||||
#include <torch/csrc/jit/frontend/builtin_functions.h>
|
#include <torch/csrc/jit/frontend/builtin_functions.h>
|
||||||
#include <torch/csrc/jit/frontend/error_report.h>
|
#include <torch/csrc/jit/frontend/error_report.h>
|
||||||
#include <torch/csrc/jit/frontend/function_schema_parser.h>
|
#include <torch/csrc/jit/frontend/function_schema_parser.h>
|
||||||
|
|
@ -469,7 +470,7 @@ static c10::optional<MatchedSchema> tryMatchSchema(
|
||||||
}
|
}
|
||||||
|
|
||||||
// construct the full name of the schema for easier look up
|
// construct the full name of the schema for easier look up
|
||||||
auto schema_name = schema.operator_name().name + "." + schema.overload_name();
|
auto schema_name = getFullSchemaName(schema);
|
||||||
|
|
||||||
return MatchedSchema{
|
return MatchedSchema{
|
||||||
std::move(positional_inputs),
|
std::move(positional_inputs),
|
||||||
|
|
@ -619,6 +620,13 @@ static Value* emitBuiltinNode(
|
||||||
return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
|
return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string getFullSchemaName(const ::c10::FunctionSchema& schema) {
|
||||||
|
if (schema.overload_name() != "") {
|
||||||
|
return schema.operator_name().name + "." + schema.overload_name();
|
||||||
|
}
|
||||||
|
return schema.operator_name().name;
|
||||||
|
}
|
||||||
|
|
||||||
// Search for operators matching the provided symbol name and input types.
|
// Search for operators matching the provided symbol name and input types.
|
||||||
// If one is found, emit a node to the graph for that operator.
|
// If one is found, emit a node to the graph for that operator.
|
||||||
Value* emitBuiltinCall(
|
Value* emitBuiltinCall(
|
||||||
|
|
@ -631,8 +639,12 @@ Value* emitBuiltinCall(
|
||||||
const auto& variants = getAllOperatorsFor(name);
|
const auto& variants = getAllOperatorsFor(name);
|
||||||
const auto& builtin_functions = getAllBuiltinFunctionsFor(name);
|
const auto& builtin_functions = getAllBuiltinFunctionsFor(name);
|
||||||
|
|
||||||
|
#if ENABLE_UPGRADERS
|
||||||
// first let's set the graph's version
|
// first let's set the graph's version
|
||||||
auto graph_version = graph.get_op_version();
|
auto graph_version = graph.get_op_version();
|
||||||
|
#else
|
||||||
|
c10::optional<size_t> graph_version = c10::nullopt;
|
||||||
|
#endif
|
||||||
|
|
||||||
std::stringstream failure_messages;
|
std::stringstream failure_messages;
|
||||||
std::vector<const FunctionSchema*> schemas;
|
std::vector<const FunctionSchema*> schemas;
|
||||||
|
|
@ -643,9 +655,7 @@ Value* emitBuiltinCall(
|
||||||
schemas.reserve(variants.size());
|
schemas.reserve(variants.size());
|
||||||
for (const std::shared_ptr<Operator>& op : variants) {
|
for (const std::shared_ptr<Operator>& op : variants) {
|
||||||
bool found_upgrader = false;
|
bool found_upgrader = false;
|
||||||
auto op_overload_name = op->schema().overload_name();
|
auto op_name = getFullSchemaName(op->schema());
|
||||||
auto op_name = op->schema().operator_name().name +
|
|
||||||
((op_overload_name != "") ? "." + op_overload_name : "");
|
|
||||||
if (graph_version.has_value()) {
|
if (graph_version.has_value()) {
|
||||||
auto version_entry = get_operator_version_map().find(op_name);
|
auto version_entry = get_operator_version_map().find(op_name);
|
||||||
if (version_entry != get_operator_version_map().end()) {
|
if (version_entry != get_operator_version_map().end()) {
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,8 @@ TORCH_API bool convertibleToList(
|
||||||
const TypePtr& type,
|
const TypePtr& type,
|
||||||
const TypePtr& list_type_);
|
const TypePtr& list_type_);
|
||||||
|
|
||||||
|
TORCH_API std::string getFullSchemaName(const ::c10::FunctionSchema& schema);
|
||||||
|
|
||||||
TORCH_API Value* emitBuiltinCall(
|
TORCH_API Value* emitBuiltinCall(
|
||||||
const SourceRange& loc,
|
const SourceRange& loc,
|
||||||
Graph& graph,
|
Graph& graph,
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include <ATen/core/symbol.h>
|
#include <ATen/core/symbol.h>
|
||||||
|
#include <caffe2/serialize/versions.h>
|
||||||
#include <torch/csrc/jit/api/module.h>
|
#include <torch/csrc/jit/api/module.h>
|
||||||
#include <torch/csrc/jit/frontend/error_report.h>
|
#include <torch/csrc/jit/frontend/error_report.h>
|
||||||
#include <torch/csrc/jit/frontend/schema_matching.h>
|
#include <torch/csrc/jit/frontend/schema_matching.h>
|
||||||
|
|
@ -319,12 +320,14 @@ struct TORCH_API BuiltinModule : public SugaredValue {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto sym = Symbol::fromQualString(name + "::" + field);
|
auto sym = Symbol::fromQualString(name + "::" + field);
|
||||||
|
#if !ENABLE_UPGRADERS
|
||||||
if (version.has_value()) {
|
if (version.has_value()) {
|
||||||
// Possibly replaces symbol with another that implements its
|
// Possibly replaces symbol with another that implements its
|
||||||
// historic behavior.
|
// historic behavior.
|
||||||
// See note [Versioned Symbols]
|
// See note [Versioned Symbols]
|
||||||
sym = get_symbol_for_version(sym, *version);
|
sym = get_symbol_for_version(sym, *version);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
return std::make_shared<BuiltinFunction>(sym, c10::nullopt);
|
return std::make_shared<BuiltinFunction>(sym, c10::nullopt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
#include <torch/csrc/jit/frontend/versioned_symbols.h>
|
#include <torch/csrc/jit/frontend/versioned_symbols.h>
|
||||||
|
|
||||||
|
#include <caffe2/serialize/versions.h>
|
||||||
#include <torch/csrc/api/include/torch/jit.h>
|
#include <torch/csrc/api/include/torch/jit.h>
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
#if !ENABLE_UPGRADERS
|
||||||
// Note [Versioned Symbols]
|
// Note [Versioned Symbols]
|
||||||
// When the schema or behavior of a symbol changes, serialized Torchscript
|
// When the schema or behavior of a symbol changes, serialized Torchscript
|
||||||
// programs using that symbol are likely to break. To prevent those breaks,
|
// programs using that symbol are likely to break. To prevent those breaks,
|
||||||
|
|
@ -105,6 +106,7 @@ uint64_t get_min_version_for_kind(const NodeKind& kind) {
|
||||||
|
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <caffe2/serialize/versions.h>
|
||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
#include <torch/csrc/jit/api/module.h>
|
#include <torch/csrc/jit/api/module.h>
|
||||||
|
|
||||||
|
|
@ -7,7 +8,7 @@
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
#if !ENABLE_UPGRADERS
|
||||||
// Maps the given symbol into an implementation of its behavior at the
|
// Maps the given symbol into an implementation of its behavior at the
|
||||||
// given version.
|
// given version.
|
||||||
// See note [Versioned Symbols]
|
// See note [Versioned Symbols]
|
||||||
|
|
@ -17,6 +18,6 @@ get_symbol_for_version(const Symbol name, const uint64_t version);
|
||||||
// Maps the given kind to the minimum version that supports it.
|
// Maps the given kind to the minimum version that supports it.
|
||||||
// See note [Dynamic Versions and torch.jit.save vs. torch.save]
|
// See note [Dynamic Versions and torch.jit.save vs. torch.save]
|
||||||
TORCH_API uint64_t get_min_version_for_kind(const NodeKind& kind);
|
TORCH_API uint64_t get_min_version_for_kind(const NodeKind& kind);
|
||||||
|
#endif
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,7 @@
|
||||||
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
|
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
|
||||||
|
|
||||||
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
#include <torch/csrc/jit/ir/irparser.h>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
@ -9,12 +12,13 @@ namespace jit {
|
||||||
static UpgradersMap upgradersMap;
|
static UpgradersMap upgradersMap;
|
||||||
|
|
||||||
void UpgradersMap::set_content(
|
void UpgradersMap::set_content(
|
||||||
std::unordered_map<std::string, std::string>&& content) {
|
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
|
||||||
// make sure we populate the map only once
|
// make sure we populate the map only once
|
||||||
std::lock_guard<std::mutex> _(lock);
|
std::lock_guard<std::mutex> _(lock);
|
||||||
if (isPopulated) {
|
if (isPopulated) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
content_ = std::move(content);
|
content_ = std::move(content);
|
||||||
isPopulated = true;
|
isPopulated = true;
|
||||||
}
|
}
|
||||||
|
|
@ -24,7 +28,12 @@ int UpgradersMap::count() {
|
||||||
return content_.size();
|
return content_.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::unordered_map<std::string, std::string>& UpgradersMap::
|
bool UpgradersMap::is_populated() {
|
||||||
|
std::lock_guard<std::mutex> _(lock);
|
||||||
|
return isPopulated;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::unordered_map<std::string, std::shared_ptr<Graph>>& UpgradersMap::
|
||||||
get_content() {
|
get_content() {
|
||||||
std::lock_guard<std::mutex> _(lock);
|
std::lock_guard<std::mutex> _(lock);
|
||||||
return content_;
|
return content_;
|
||||||
|
|
@ -34,7 +43,9 @@ void UpgradersMap::test_only_set_content(
|
||||||
const std::unordered_map<std::string, std::string>& content) {
|
const std::unordered_map<std::string, std::string>& content) {
|
||||||
std::lock_guard<std::mutex> _(lock);
|
std::lock_guard<std::mutex> _(lock);
|
||||||
for (const auto& entry : content) {
|
for (const auto& entry : content) {
|
||||||
content_.insert(entry);
|
auto graph = std::make_shared<Graph>();
|
||||||
|
torch::jit::parseIR(entry.second, graph.get());
|
||||||
|
content_.insert(std::make_pair(entry.first, graph));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void UpgradersMap::test_only_remove_content(
|
void UpgradersMap::test_only_remove_content(
|
||||||
|
|
@ -46,7 +57,7 @@ void UpgradersMap::test_only_remove_content(
|
||||||
}
|
}
|
||||||
|
|
||||||
void populate_upgraders_map(
|
void populate_upgraders_map(
|
||||||
std::unordered_map<std::string, std::string>&& content) {
|
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
|
||||||
upgradersMap.set_content(std::move(content));
|
upgradersMap.set_content(std::move(content));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -54,7 +65,12 @@ int get_upgraders_map_size() {
|
||||||
return upgradersMap.count();
|
return upgradersMap.count();
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::unordered_map<std::string, std::string>& dump_upgraders_map() {
|
bool is_upgraders_map_populated() {
|
||||||
|
return upgradersMap.is_populated();
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::unordered_map<std::string, std::shared_ptr<Graph>>&
|
||||||
|
dump_upgraders_map() {
|
||||||
return upgradersMap.get_content();
|
return upgradersMap.get_content();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <c10/macros/Export.h>
|
#include <c10/macros/Export.h>
|
||||||
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@ -10,9 +11,11 @@ namespace jit {
|
||||||
|
|
||||||
class UpgradersMap {
|
class UpgradersMap {
|
||||||
public:
|
public:
|
||||||
void set_content(std::unordered_map<std::string, std::string>&& content);
|
void set_content(
|
||||||
|
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content);
|
||||||
int count();
|
int count();
|
||||||
const std::unordered_map<std::string, std::string>& get_content();
|
const std::unordered_map<std::string, std::shared_ptr<Graph>>& get_content();
|
||||||
|
bool is_populated();
|
||||||
// THESE METHODS ARE ONLY USED FOR TESTING PURPOSES
|
// THESE METHODS ARE ONLY USED FOR TESTING PURPOSES
|
||||||
void test_only_set_content(
|
void test_only_set_content(
|
||||||
const std::unordered_map<std::string, std::string>& content);
|
const std::unordered_map<std::string, std::string>& content);
|
||||||
|
|
@ -20,17 +23,19 @@ class UpgradersMap {
|
||||||
const std::unordered_map<std::string, std::string>& content);
|
const std::unordered_map<std::string, std::string>& content);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unordered_map<std::string, std::string> content_;
|
std::unordered_map<std::string, std::shared_ptr<Graph>> content_;
|
||||||
std::mutex lock;
|
std::mutex lock;
|
||||||
bool isPopulated = false;
|
bool isPopulated = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
TORCH_API void populate_upgraders_map(
|
TORCH_API void populate_upgraders_map(
|
||||||
std::unordered_map<std::string, std::string>&& content);
|
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content);
|
||||||
|
|
||||||
TORCH_API int get_upgraders_map_size();
|
TORCH_API int get_upgraders_map_size();
|
||||||
|
|
||||||
TORCH_API const std::unordered_map<std::string, std::string>&
|
TORCH_API bool is_upgraders_map_populated();
|
||||||
|
|
||||||
|
TORCH_API const std::unordered_map<std::string, std::shared_ptr<Graph>>&
|
||||||
dump_upgraders_map();
|
dump_upgraders_map();
|
||||||
|
|
||||||
// THESE TWO METHODS BELOW ARE ONLY USED FOR TESTING
|
// THESE TWO METHODS BELOW ARE ONLY USED FOR TESTING
|
||||||
|
|
|
||||||
78
torch/csrc/jit/operator_upgraders/upgraders_entry.cpp
Normal file
78
torch/csrc/jit/operator_upgraders/upgraders_entry.cpp
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
|
||||||
|
|
||||||
|
#include <ATen/core/stack.h>
|
||||||
|
#include <c10/macros/Export.h>
|
||||||
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||||
|
#include <torch/csrc/jit/api/function_impl.h>
|
||||||
|
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
||||||
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
static std::unordered_map<std::string, std::string> kUpgradersEntryMap(
|
||||||
|
{{"div_Tensor_0_3", R"SCRIPT(
|
||||||
|
def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
|
||||||
|
if (self.is_floating_point() or other.is_floating_point()):
|
||||||
|
return self.true_divide(other)
|
||||||
|
return self.divide(other, rounding_mode='trunc')
|
||||||
|
)SCRIPT"},
|
||||||
|
{"div_Scalar_0_3", R"SCRIPT(
|
||||||
|
def div_Scalar_0_3(self: Tensor, other: number) -> Tensor:
|
||||||
|
if (self.is_floating_point() or isinstance(other, float)):
|
||||||
|
return self.true_divide(other)
|
||||||
|
return self.divide(other, rounding_mode='trunc')
|
||||||
|
)SCRIPT"},
|
||||||
|
{"div_out_0_3", R"SCRIPT(
|
||||||
|
def div_out_0_3(self: Tensor, other: Tensor, *, out: Tensor) -> Tensor:
|
||||||
|
if (self.is_floating_point() or other.is_floating_point() or out.is_floating_point()):
|
||||||
|
return self.true_divide(other, out=out)
|
||||||
|
return self.divide(other, rounding_mode='trunc', out=out)
|
||||||
|
)SCRIPT"},
|
||||||
|
{"div__Tensor_0_3", R"SCRIPT(
|
||||||
|
def div__Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
|
||||||
|
if (self.is_floating_point() or other.is_floating_point()):
|
||||||
|
return self.true_divide_(other)
|
||||||
|
return self.divide_(other, rounding_mode='trunc')
|
||||||
|
)SCRIPT"},
|
||||||
|
{"div__Scalar_0_3", R"SCRIPT(
|
||||||
|
def div__Scalar_0_3(self: Tensor, other: number) -> Tensor:
|
||||||
|
if (self.is_floating_point() or isinstance(other, float)):
|
||||||
|
return self.true_divide_(other)
|
||||||
|
return self.divide_(other, rounding_mode='trunc')
|
||||||
|
)SCRIPT"},
|
||||||
|
{"full_0_4", R"SCRIPT(
|
||||||
|
def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
|
||||||
|
layout:Optional[int]=None, device:Optional[Device]=None,
|
||||||
|
pin_memory:Optional[bool]=None) -> Tensor:
|
||||||
|
if dtype is None:
|
||||||
|
fill_value = float(fill_value)
|
||||||
|
return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
|
||||||
|
)SCRIPT"},
|
||||||
|
{"full_out_0_4", R"SCRIPT(
|
||||||
|
def full_out_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
|
||||||
|
return torch.full(size, fill_value, out=out)
|
||||||
|
)SCRIPT"}});
|
||||||
|
|
||||||
|
using UpgraderMap = std::unordered_map<std::string, std::shared_ptr<Graph>>;
|
||||||
|
void populate_upgraders_graph_map() {
|
||||||
|
if (!is_upgraders_map_populated()) {
|
||||||
|
UpgraderMap populate_content;
|
||||||
|
for (const auto& entry : kUpgradersEntryMap) {
|
||||||
|
auto cu = std::make_shared<CompilationUnit>();
|
||||||
|
cu->define(c10::nullopt, entry.second, nativeResolver(), nullptr);
|
||||||
|
Function& jitFunc = cu->get_function(entry.first);
|
||||||
|
GraphFunction& graphFunction = toGraphFunction(jitFunc);
|
||||||
|
populate_content.insert(
|
||||||
|
std::make_pair(entry.first, graphFunction.graph()));
|
||||||
|
}
|
||||||
|
|
||||||
|
populate_upgraders_map(std::forward<UpgraderMap>(populate_content));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
14
torch/csrc/jit/operator_upgraders/upgraders_entry.h
Normal file
14
torch/csrc/jit/operator_upgraders/upgraders_entry.h
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
#pragma once
|
||||||
|
#include <c10/macros/Export.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <mutex>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
TORCH_API void populate_upgraders_graph_map();
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
12
torch/csrc/jit/operator_upgraders/upgraders_guard.cpp
Normal file
12
torch/csrc/jit/operator_upgraders/upgraders_guard.cpp
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
#include <caffe2/serialize/versions.h>
|
||||||
|
#include <torch/csrc/jit/operator_upgraders/upgraders_guard.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
bool is_upgraders_enabled() {
|
||||||
|
return ENABLE_UPGRADERS;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
10
torch/csrc/jit/operator_upgraders/upgraders_guard.h
Normal file
10
torch/csrc/jit/operator_upgraders/upgraders_guard.h
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
#pragma once
|
||||||
|
#include <c10/macros/Export.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
TORCH_API bool is_upgraders_enabled();
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
|
|
@ -31,11 +31,11 @@ static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersi
|
||||||
{"aten::div_.Tensor",
|
{"aten::div_.Tensor",
|
||||||
{{4,
|
{{4,
|
||||||
"div__Tensor_0_3",
|
"div__Tensor_0_3",
|
||||||
"aten::div_.Tensor(Tensor(a!), Tensor other) -> Tensor(a!)"}}},
|
"aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"}}},
|
||||||
{"aten::div_.Scalar",
|
{"aten::div_.Scalar",
|
||||||
{{4,
|
{{4,
|
||||||
"div__Scalar_0_3",
|
"div__Scalar_0_3",
|
||||||
"aten::div_.Scalar(Tensor(a!), Tensor other) -> Tensor(a!)"}}},
|
"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"}}},
|
||||||
{"aten::full",
|
{"aten::full",
|
||||||
{{5,
|
{{5,
|
||||||
"full_0_4",
|
"full_0_4",
|
||||||
|
|
|
||||||
|
|
@ -15,29 +15,6 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
static std::unordered_map<std::string, std::shared_ptr<Graph>> upgraderCache;
|
|
||||||
|
|
||||||
std::shared_ptr<Graph> getUpgraderGraph(const std::string& upgrader_name) {
|
|
||||||
auto it = upgraderCache.find(upgrader_name);
|
|
||||||
if (it != upgraderCache.end()) {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name);
|
|
||||||
TORCH_INTERNAL_ASSERT(
|
|
||||||
upgrader_graph_entry != dump_upgraders_map().end(),
|
|
||||||
"Corresponding upgrader graph for ",
|
|
||||||
upgrader_name,
|
|
||||||
" must exist.",
|
|
||||||
" This upgrader"
|
|
||||||
" might be deprecated.");
|
|
||||||
|
|
||||||
auto upgrader_graph = std::make_shared<Graph>();
|
|
||||||
parseIR(upgrader_graph_entry->second, upgrader_graph.get());
|
|
||||||
upgraderCache[upgrader_name] = upgrader_graph;
|
|
||||||
return upgrader_graph;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct OldOpsReplacerWithUpgraders {
|
struct OldOpsReplacerWithUpgraders {
|
||||||
OldOpsReplacerWithUpgraders(std::shared_ptr<Graph> graph)
|
OldOpsReplacerWithUpgraders(std::shared_ptr<Graph> graph)
|
||||||
: graph_(std::move(graph)) {}
|
: graph_(std::move(graph)) {}
|
||||||
|
|
@ -52,9 +29,7 @@ struct OldOpsReplacerWithUpgraders {
|
||||||
Node* node = graph_it.next();
|
Node* node = graph_it.next();
|
||||||
while (node) {
|
while (node) {
|
||||||
if (auto schema = node->maybeSchema()) {
|
if (auto schema = node->maybeSchema()) {
|
||||||
auto schema_name = schema->name() +
|
auto schema_name = getFullSchemaName(*schema);
|
||||||
(schema->overload_name() != "" ? "." + schema->overload_name()
|
|
||||||
: "");
|
|
||||||
// this implies there was a version bump because of this operator
|
// this implies there was a version bump because of this operator
|
||||||
auto version_entry = get_operator_version_map().find(schema_name);
|
auto version_entry = get_operator_version_map().find(schema_name);
|
||||||
if (version_entry != get_operator_version_map().end()) {
|
if (version_entry != get_operator_version_map().end()) {
|
||||||
|
|
@ -74,8 +49,16 @@ struct OldOpsReplacerWithUpgraders {
|
||||||
}
|
}
|
||||||
auto upgrader_entry_val = upgrader_entry.value();
|
auto upgrader_entry_val = upgrader_entry.value();
|
||||||
auto upgrader_name = upgrader_entry_val.upgrader_name;
|
auto upgrader_name = upgrader_entry_val.upgrader_name;
|
||||||
auto upgrader_graph = getUpgraderGraph(upgrader_name);
|
auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name);
|
||||||
|
TORCH_INTERNAL_ASSERT(
|
||||||
|
upgrader_graph_entry != dump_upgraders_map().end(),
|
||||||
|
"Corresponding upgrader graph for ",
|
||||||
|
upgrader_name,
|
||||||
|
" must exist.",
|
||||||
|
" This upgrader"
|
||||||
|
" might be deprecated.");
|
||||||
|
|
||||||
|
auto upgrader_graph = upgrader_graph_entry->second;
|
||||||
// inline the upgrader function body
|
// inline the upgrader function body
|
||||||
WithInsertPoint guard(node);
|
WithInsertPoint guard(node);
|
||||||
auto new_outputs = insertGraph(
|
auto new_outputs = insertGraph(
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@
|
||||||
#include <torch/csrc/jit/mobile/model_compatibility.h>
|
#include <torch/csrc/jit/mobile/model_compatibility.h>
|
||||||
#include <torch/csrc/jit/mobile/module.h>
|
#include <torch/csrc/jit/mobile/module.h>
|
||||||
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
|
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
|
||||||
|
#include <torch/csrc/jit/operator_upgraders/upgraders_guard.h>
|
||||||
#include <torch/csrc/jit/operator_upgraders/version_map.h>
|
#include <torch/csrc/jit/operator_upgraders/version_map.h>
|
||||||
#include <torch/csrc/jit/python/module_python.h>
|
#include <torch/csrc/jit/python/module_python.h>
|
||||||
#include <torch/csrc/jit/python/python_ivalue.h>
|
#include <torch/csrc/jit/python/python_ivalue.h>
|
||||||
|
|
@ -1731,7 +1732,8 @@ void initJitScriptBindings(PyObject* module) {
|
||||||
return Decl(p.parseTypeComment());
|
return Decl(p.parseTypeComment());
|
||||||
});
|
});
|
||||||
|
|
||||||
m.def("_populate_upgraders_map", &populate_upgraders_map);
|
m.def("_is_upgraders_enabled", &is_upgraders_enabled);
|
||||||
|
|
||||||
m.def("_get_upgraders_map_size", &get_upgraders_map_size);
|
m.def("_get_upgraders_map_size", &get_upgraders_map_size);
|
||||||
m.def("_dump_upgraders_map", &dump_upgraders_map);
|
m.def("_dump_upgraders_map", &dump_upgraders_map);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@
|
||||||
#endif
|
#endif
|
||||||
#include <torch/csrc/jit/frontend/script_type_parser.h>
|
#include <torch/csrc/jit/frontend/script_type_parser.h>
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
|
||||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||||
#include <torch/csrc/jit/serialization/import_read.h>
|
#include <torch/csrc/jit/serialization/import_read.h>
|
||||||
#include <torch/csrc/jit/serialization/import_source.h>
|
#include <torch/csrc/jit/serialization/import_source.h>
|
||||||
|
|
@ -20,6 +21,7 @@
|
||||||
#include <caffe2/serialize/file_adapter.h>
|
#include <caffe2/serialize/file_adapter.h>
|
||||||
#include <caffe2/serialize/inline_container.h>
|
#include <caffe2/serialize/inline_container.h>
|
||||||
#include <caffe2/serialize/istream_adapter.h>
|
#include <caffe2/serialize/istream_adapter.h>
|
||||||
|
#include <caffe2/serialize/versions.h>
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
|
@ -239,6 +241,10 @@ graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_ze
|
||||||
Module ScriptModuleDeserializer::deserialize(
|
Module ScriptModuleDeserializer::deserialize(
|
||||||
c10::optional<at::Device> device,
|
c10::optional<at::Device> device,
|
||||||
ExtraFilesMap& extra_files) {
|
ExtraFilesMap& extra_files) {
|
||||||
|
// we populate the upgraders map before any load starts
|
||||||
|
#if ENABLE_UPGRADERS
|
||||||
|
populate_upgraders_graph_map();
|
||||||
|
#endif
|
||||||
C10_LOG_API_USAGE_ONCE("torch.script.load");
|
C10_LOG_API_USAGE_ONCE("torch.script.load");
|
||||||
device_ = device;
|
device_ = device;
|
||||||
// Load extra files.
|
// Load extra files.
|
||||||
|
|
|
||||||
|
|
@ -577,7 +577,9 @@ void SourceImporterImpl::importClass(
|
||||||
/*propResolvers=*/{},
|
/*propResolvers=*/{},
|
||||||
methods,
|
methods,
|
||||||
method_resolvers,
|
method_resolvers,
|
||||||
&self);
|
&self,
|
||||||
|
/*shouldMangle=*/false,
|
||||||
|
/*operator_set_version=*/version_);
|
||||||
cu_->define_hooks(
|
cu_->define_hooks(
|
||||||
qualified_classname,
|
qualified_classname,
|
||||||
hooks,
|
hooks,
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <c10/util/StringUtil.h>
|
#include <c10/util/StringUtil.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
|
#include <caffe2/serialize/versions.h>
|
||||||
#include <torch/csrc/jit/api/function_impl.h>
|
#include <torch/csrc/jit/api/function_impl.h>
|
||||||
#include <torch/csrc/jit/api/module.h>
|
#include <torch/csrc/jit/api/module.h>
|
||||||
#include <torch/csrc/jit/frontend/error_report.h>
|
#include <torch/csrc/jit/frontend/error_report.h>
|
||||||
|
|
@ -13,6 +14,7 @@
|
||||||
#include <torch/csrc/jit/ir/attributes.h>
|
#include <torch/csrc/jit/ir/attributes.h>
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
#include <torch/csrc/jit/ir/ir_views.h>
|
#include <torch/csrc/jit/ir/ir_views.h>
|
||||||
|
#include <torch/csrc/jit/operator_upgraders/version_map.h>
|
||||||
#include <torch/csrc/jit/resource_guard.h>
|
#include <torch/csrc/jit/resource_guard.h>
|
||||||
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
|
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
|
||||||
|
|
||||||
|
|
@ -741,9 +743,22 @@ struct PythonPrintImpl {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void checkVersion(const Node* const node) {
|
void checkVersion(Node* node) {
|
||||||
|
#if ENABLE_UPGRADERS
|
||||||
|
if (auto schema = node->maybeSchema()) {
|
||||||
|
auto schema_name = getFullSchemaName(*schema);
|
||||||
|
auto version_entry = get_operator_version_map().find(schema_name);
|
||||||
|
if (version_entry != get_operator_version_map().end()) {
|
||||||
|
const auto& entry = version_entry->second;
|
||||||
|
// TODO (tugsuu) move this calculation into a seperate step.
|
||||||
|
min_version_ = std::max(
|
||||||
|
min_version_, uint64_t(entry[entry.size() - 1].bumped_at_version));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
min_version_ =
|
min_version_ =
|
||||||
std::max(min_version_, get_min_version_for_kind(node->kind()));
|
std::max(min_version_, get_min_version_for_kind(node->kind()));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void printNode(Node* node, bool print_const) {
|
void printNode(Node* node, bool print_const) {
|
||||||
|
|
@ -1594,7 +1609,11 @@ struct PythonPrintImpl {
|
||||||
bool enforce_importable_;
|
bool enforce_importable_;
|
||||||
|
|
||||||
// The least version that supports all printed ops
|
// The least version that supports all printed ops
|
||||||
|
#if ENABLE_UPGRADERS
|
||||||
|
uint64_t min_version_ = caffe2::serialize::kMinSupportedFileFormatVersion;
|
||||||
|
#else
|
||||||
uint64_t min_version_ = 0;
|
uint64_t min_version_ = 0;
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
PythonPrint::PythonPrint(
|
PythonPrint::PythonPrint(
|
||||||
|
|
|
||||||
|
|
@ -147,11 +147,6 @@ def load(f, map_location=None, _extra_files=None):
|
||||||
os.remove("scriptmodule.pt")
|
os.remove("scriptmodule.pt")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO (tugsuu) Putting this on top level creates
|
|
||||||
# circular import issue. We need to fix this later
|
|
||||||
from torch.jit.operator_upgraders import populate_upgraders_map
|
|
||||||
populate_upgraders_map()
|
|
||||||
|
|
||||||
if isinstance(f, string_classes):
|
if isinstance(f, string_classes):
|
||||||
if not os.path.exists(f): # type: ignore[type-var]
|
if not os.path.exists(f): # type: ignore[type-var]
|
||||||
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]
|
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]
|
||||||
|
|
|
||||||
|
|
@ -116,12 +116,5 @@ def generate_bytecode() -> List:
|
||||||
yaml_content.append(entry)
|
yaml_content.append(entry)
|
||||||
return yaml_content
|
return yaml_content
|
||||||
|
|
||||||
def populate_upgraders_map():
|
|
||||||
upgrader_set = collect_available_upgraders()
|
|
||||||
content = {}
|
|
||||||
for upgrader_name in upgrader_set:
|
|
||||||
content[upgrader_name] = str(globals()[upgrader_name].graph)
|
|
||||||
torch._C._populate_upgraders_map(content)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
raise RuntimeError("This file is not meant to be run directly")
|
raise RuntimeError("This file is not meant to be run directly")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user