mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Bump version number to 7 and compile old operators with old schema (#68358)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68358 Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D33433730 Pulled By: tugsbayasgalan fbshipit-source-id: 202c58365bae13195d3545cefcb0da9162b02151
This commit is contained in:
parent
8bdbe94344
commit
b0fdca8855
|
|
@ -164,7 +164,13 @@ class TORCH_API PyTorchStreamWriter final {
|
|||
std::string padding_;
|
||||
std::ofstream file_stream_;
|
||||
std::function<size_t(const void*, size_t)> writer_func_;
|
||||
// This number will be updated when the model has operators
|
||||
// that have valid upgraders.
|
||||
#if ENABLE_UPGRADERS
|
||||
uint64_t version_ = kMinProducedFileFormatVersion;
|
||||
#else
|
||||
uint64_t version_ = kProducedFileFormatVersion;
|
||||
#endif
|
||||
bool finalized_ = false;
|
||||
bool err_seen_ = false;
|
||||
friend size_t ostream_write_func(
|
||||
|
|
|
|||
|
|
@ -1,12 +1,21 @@
|
|||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
// Flag that controls if we want to enable upgraders
|
||||
// in the server side. When this flag is set to False,
|
||||
// it will switch to old dynamic versioning approach
|
||||
#define ENABLE_UPGRADERS false
|
||||
|
||||
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
|
||||
|
||||
#if ENABLE_UPGRADERS
|
||||
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x7L;
|
||||
#else
|
||||
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
|
||||
#endif
|
||||
|
||||
// Versions (i.e. why was the version number bumped?)
|
||||
|
||||
|
|
@ -47,7 +56,23 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
|
|||
// 5. (Dynamic) Stops torch.full inferring a floating point dtype
|
||||
// when given bool or integer fill values.
|
||||
// 6. Write version string to `./data/version` instead of `version`.
|
||||
|
||||
#if ENABLE_UPGRADERS
|
||||
// This is set to 7 from 3 due to a different interpretation of what
|
||||
// file format version is. Whenever there is new upgrader introduced,
|
||||
// this number should be bumped.
|
||||
// 1. aten::div is changed at version 4
|
||||
// 2. aten::full is changed at version 5
|
||||
// 3. torch.package uses version 6
|
||||
constexpr uint64_t kProducedFileFormatVersion = 0x7L;
|
||||
#else
|
||||
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
|
||||
#endif
|
||||
|
||||
// Absolute minimum version we will write packages. This
|
||||
// means that every package from now on will always be
|
||||
// greater than this number.
|
||||
constexpr uint64_t kMinProducedFileFormatVersion = 0x3L;
|
||||
|
||||
// The version we write when the archive contains bytecode.
|
||||
// It must be higher or eq to kProducedFileFormatVersion.
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ set(JIT_TEST_SRCS
|
|||
${JIT_TEST_ROOT}/test_alias_analysis.cpp
|
||||
${JIT_TEST_ROOT}/test_argument_spec.cpp
|
||||
${JIT_TEST_ROOT}/test_autodiff.cpp
|
||||
${JIT_TEST_ROOT}/test_load_upgraders.cpp
|
||||
${JIT_TEST_ROOT}/test_op_replacement.cpp
|
||||
${JIT_TEST_ROOT}/test_upgrader_utils.cpp
|
||||
${JIT_TEST_ROOT}/test_backend.cpp
|
||||
|
|
|
|||
45
test/cpp/jit/test_load_upgraders.cpp
Normal file
45
test/cpp/jit/test_load_upgraders.cpp
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
#include <caffe2/serialize/versions.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/version_map.h>
|
||||
#include <torch/csrc/jit/serialization/import.h>
|
||||
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
#if ENABLE_UPGRADERS
|
||||
// Basic tests to check if C++ torch::jit::load
|
||||
// can load the upgraders fine
|
||||
// TODO (tugsuu) add more tests
|
||||
TEST(UpgraderLoad, CanPopulateUpgradersGraph) {
|
||||
Module m("m");
|
||||
m.define(R"(
|
||||
def forward(self, x: Tensor):
|
||||
b = 5
|
||||
return torch.div(x, b)
|
||||
)");
|
||||
std::stringstream ms;
|
||||
m.save(ms);
|
||||
auto loaded_m = torch::jit::load(ms);
|
||||
auto version_map = get_operator_version_map();
|
||||
auto upgraders = dump_upgraders_map();
|
||||
|
||||
for (const auto& entry : version_map) {
|
||||
auto list_of_upgraders_for_op = entry.second;
|
||||
for (const auto& upgrader_entry : list_of_upgraders_for_op) {
|
||||
EXPECT_TRUE(
|
||||
upgraders.find(upgrader_entry.upgrader_name) != upgraders.end());
|
||||
}
|
||||
}
|
||||
|
||||
auto test_graph = loaded_m.get_method("forward").graph();
|
||||
// should have saved with version 4, so it is still up to date
|
||||
testing::FileCheck().check_count("aten::div", 1, true)->run(*test_graph);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
553
test/jit/test_legacy_upgraders.py
Normal file
553
test/jit/test_legacy_upgraders.py
Normal file
|
|
@ -0,0 +1,553 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from itertools import product as product
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
# Legacy test cases for dynamic operator versioning
|
||||
class TestLegacyUpgraders(JitTestCase):
|
||||
reason: str = "Skipped due to new global operator versioning for operators"
|
||||
|
||||
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||
def test_versioned_symbols(self):
|
||||
"""
|
||||
Tests Torchscript symbol versioning. See note [Versioned Symbols].
|
||||
This test uses an undocumented, test-only function
|
||||
torch._test_serialization_subcmul.
|
||||
This function is implemented as (a - alpha * b) with a default value
|
||||
of 1 for alpha. In file format version 2, however, it was implemented
|
||||
as (b - alpha * a) with a default value of 2 for alpha.
|
||||
This test verifies a module seralized with file format version 2
|
||||
exhibits the old behavior, and that the same module newly serialized
|
||||
exhibits the current behavior.
|
||||
"""
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
|
||||
def forward(self, a, b, alpha: float):
|
||||
no_alpha = torch._test_serialization_subcmul(a, b)
|
||||
with_alpha = torch._test_serialization_subcmul(a, b, alpha)
|
||||
return no_alpha, with_alpha
|
||||
|
||||
def historic_subcmul(a, b, alpha=2):
|
||||
return b - alpha * a
|
||||
|
||||
def current_subcmul(a, b, alpha=1):
|
||||
return a - alpha * b
|
||||
|
||||
# Loads and verifies the historic behavior of the module
|
||||
# that was serialized with version 2
|
||||
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
|
||||
a = torch.randn((5,))
|
||||
b = torch.randn((5,))
|
||||
alpha = random.random()
|
||||
args = (a, b, alpha)
|
||||
no_alpha_v2, with_alpha_v2 = module_v2(*args)
|
||||
self.assertEqual(no_alpha_v2, historic_subcmul(a, b))
|
||||
self.assertEqual(with_alpha_v2, historic_subcmul(*args))
|
||||
|
||||
# Scripts, saves, loads and verifies the current behavior of the module
|
||||
scripted_module = torch.jit.script(MyModule())
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(scripted_module, buffer)
|
||||
buffer.seek(0)
|
||||
module_current = torch.jit.load(buffer)
|
||||
no_alpha_current, with_alpha_current = module_current(*args)
|
||||
self.assertEqual(no_alpha_current, current_subcmul(a, b))
|
||||
self.assertEqual(with_alpha_current, current_subcmul(*args))
|
||||
|
||||
# Helper that returns the module after saving and loading
|
||||
def _save_load_module(self, m):
|
||||
scripted_module = torch.jit.script(m())
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(scripted_module, buffer)
|
||||
buffer.seek(0)
|
||||
return torch.jit.load(buffer)
|
||||
|
||||
# Helper which returns the result of a function or the exception the
|
||||
# function threw.
|
||||
def _try_fn(self, fn, *args, **kwargs):
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
|
||||
def _verify_no(self, kind, m):
|
||||
self._verify_count(kind, m, 0)
|
||||
|
||||
def _verify_count(self, kind, m, count):
|
||||
node_count = sum(str(n).count(kind) for n in m.graph.nodes())
|
||||
self.assertEqual(node_count, count)
|
||||
|
||||
"""
|
||||
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
|
||||
to call either aten::true_divide(_), if an input is a float type,
|
||||
or truncated aten::divide(_) otherwise.
|
||||
NOTE: currently compares against current div behavior, too, since
|
||||
div behavior has not yet been updated.
|
||||
"""
|
||||
|
||||
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||
def test_versioned_div_tensor(self):
|
||||
def historic_div(self, other):
|
||||
if self.is_floating_point() or other.is_floating_point():
|
||||
return self.true_divide(other)
|
||||
return self.divide(other, rounding_mode='trunc')
|
||||
|
||||
# Tensor x Tensor
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
|
||||
def forward(self, a, b):
|
||||
result_0 = a / b
|
||||
result_1 = torch.div(a, b)
|
||||
result_2 = a.div(b)
|
||||
|
||||
return result_0, result_1, result_2
|
||||
|
||||
# Loads historic module
|
||||
try:
|
||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::div", v3_module, 6) # true_divide and divide alias to div
|
||||
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
self._verify_count("aten::div", current_module, 3)
|
||||
|
||||
vals = (2., 3., 2, 3)
|
||||
for val_a, val_b in product(vals, vals):
|
||||
a = torch.tensor((val_a,))
|
||||
b = torch.tensor((val_b,))
|
||||
|
||||
def _helper(m, fn):
|
||||
m_results = self._try_fn(m, a, b)
|
||||
fn_result = self._try_fn(fn, a, b)
|
||||
|
||||
if isinstance(m_results, Exception):
|
||||
self.assertTrue(isinstance(fn_result, Exception))
|
||||
else:
|
||||
for result in m_results:
|
||||
self.assertEqual(result, fn_result)
|
||||
|
||||
_helper(v3_module, historic_div)
|
||||
_helper(current_module, torch.div)
|
||||
|
||||
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||
def test_versioned_div_tensor_inplace(self):
|
||||
def historic_div_(self, other):
|
||||
if self.is_floating_point() or other.is_floating_point():
|
||||
return self.true_divide_(other)
|
||||
return self.divide_(other, rounding_mode='trunc')
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
|
||||
def forward(self, a, b):
|
||||
a /= b
|
||||
return a
|
||||
|
||||
try:
|
||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::div", v3_module, 2) # true_divide and divide both alias to div
|
||||
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
self._verify_count("aten::div", current_module, 1)
|
||||
|
||||
vals = (2., 3., 2, 3)
|
||||
for val_a, val_b in product(vals, vals):
|
||||
a = torch.tensor((val_a,))
|
||||
b = torch.tensor((val_b,))
|
||||
|
||||
def _helper(m, fn):
|
||||
fn_result = self._try_fn(fn, a.clone(), b)
|
||||
m_result = self._try_fn(m, a, b)
|
||||
|
||||
if isinstance(m_result, Exception):
|
||||
self.assertTrue(fn_result, Exception)
|
||||
else:
|
||||
self.assertEqual(m_result, fn_result)
|
||||
self.assertEqual(m_result, a)
|
||||
|
||||
_helper(v3_module, historic_div_)
|
||||
|
||||
# Recreates a since it was modified in place
|
||||
a = torch.tensor((val_a,))
|
||||
_helper(current_module, torch.Tensor.div_)
|
||||
|
||||
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||
def test_versioned_div_tensor_out(self):
|
||||
def historic_div_out(self, other, out):
|
||||
if self.is_floating_point() or other.is_floating_point() or out.is_floating_point():
|
||||
return torch.true_divide(self, other, out=out)
|
||||
return torch.divide(self, other, out=out, rounding_mode='trunc')
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
|
||||
def forward(self, a, b, out):
|
||||
return a.div(b, out=out)
|
||||
|
||||
try:
|
||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::div", v3_module, 2) # true_divide and divide alias to div
|
||||
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
self._verify_count("aten::div", current_module, 1)
|
||||
|
||||
vals = (2., 3., 2, 3)
|
||||
for val_a, val_b in product(vals, vals):
|
||||
a = torch.tensor((val_a,))
|
||||
b = torch.tensor((val_b,))
|
||||
|
||||
for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)):
|
||||
def _helper(m, fn):
|
||||
fn_result = None
|
||||
if fn is torch.div:
|
||||
fn_result = self._try_fn(fn, a, b, out=out.clone())
|
||||
else:
|
||||
fn_result = self._try_fn(fn, a, b, out.clone())
|
||||
m_result = self._try_fn(m, a, b, out)
|
||||
|
||||
if isinstance(m_result, Exception):
|
||||
self.assertTrue(fn_result, Exception)
|
||||
else:
|
||||
self.assertEqual(m_result, fn_result)
|
||||
self.assertEqual(m_result, out)
|
||||
|
||||
_helper(v3_module, historic_div_out)
|
||||
_helper(current_module, torch.div)
|
||||
|
||||
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||
def test_versioned_div_scalar(self):
|
||||
def historic_div_scalar_float(self, other: float):
|
||||
return torch.true_divide(self, other)
|
||||
|
||||
def historic_div_scalar_int(self, other: int):
|
||||
if self.is_floating_point():
|
||||
return torch.true_divide(self, other)
|
||||
return torch.divide(self, other, rounding_mode='trunc')
|
||||
|
||||
class MyModuleFloat(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModuleFloat, self).__init__()
|
||||
|
||||
def forward(self, a, b: float):
|
||||
return a / b
|
||||
|
||||
class MyModuleInt(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModuleInt, self).__init__()
|
||||
|
||||
def forward(self, a, b: int):
|
||||
return a / b
|
||||
|
||||
try:
|
||||
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt")
|
||||
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_int_v3.pt")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
for m in (v3_module_float, v3_module_int):
|
||||
self._verify_count("aten::div", m, 2) # true_divide and divide alias to div
|
||||
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
|
||||
|
||||
current_module_float = self._save_load_module(MyModuleFloat)
|
||||
current_module_int = self._save_load_module(MyModuleInt)
|
||||
|
||||
for m in (current_module_float, current_module_int):
|
||||
self._verify_count("aten::div", m, 1)
|
||||
|
||||
vals = (2., 3., 2, 3)
|
||||
for val_a, val_b in product(vals, vals):
|
||||
a = torch.tensor((val_a,))
|
||||
b = val_b
|
||||
|
||||
def _helper(m, fn):
|
||||
m_result = self._try_fn(m, a, b)
|
||||
fn_result = self._try_fn(fn, a, b)
|
||||
|
||||
if isinstance(m_result, Exception):
|
||||
self.assertTrue(fn_result, Exception)
|
||||
else:
|
||||
self.assertEqual(m_result, fn_result)
|
||||
|
||||
if isinstance(b, float):
|
||||
_helper(v3_module_float, historic_div_scalar_float)
|
||||
_helper(current_module_float, torch.div)
|
||||
else:
|
||||
_helper(v3_module_int, historic_div_scalar_int)
|
||||
_helper(current_module_int, torch.div)
|
||||
|
||||
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||
def test_versioned_div_scalar_reciprocal(self):
|
||||
def historic_div_scalar_float_reciprocal(self, other: float):
|
||||
return other / self
|
||||
|
||||
def historic_div_scalar_int_reciprocal(self, other: int):
|
||||
if self.is_floating_point():
|
||||
return other / self
|
||||
return torch.divide(other, self, rounding_mode='trunc')
|
||||
|
||||
class MyModuleFloat(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModuleFloat, self).__init__()
|
||||
|
||||
def forward(self, a, b: float):
|
||||
return b / a
|
||||
|
||||
class MyModuleInt(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModuleInt, self).__init__()
|
||||
|
||||
def forward(self, a, b: int):
|
||||
return b / a
|
||||
|
||||
try:
|
||||
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt")
|
||||
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
# NOTE: number / tensor is rewritten to torch.reciprocal(a) * b
|
||||
# so true_divide and floor_divide do not appear in their graphs
|
||||
for m in (v3_module_float, v3_module_int):
|
||||
self._verify_no("aten::div", m)
|
||||
self._verify_no("aten::true_divide", m)
|
||||
self._verify_no("aten::floor_divide", m)
|
||||
self._verify_count("aten::reciprocal", m, 1)
|
||||
|
||||
current_module_float = self._save_load_module(MyModuleFloat)
|
||||
current_module_int = self._save_load_module(MyModuleInt)
|
||||
|
||||
vals = (2., 3., 2, 3)
|
||||
for val_a, val_b in product(vals, vals):
|
||||
a = torch.tensor((val_a,))
|
||||
b = val_b
|
||||
|
||||
def _helper(m, fn):
|
||||
m_result = self._try_fn(m, a, b)
|
||||
fn_result = None
|
||||
# Reverses argument order for torch.div
|
||||
if fn is torch.div:
|
||||
fn_result = self._try_fn(torch.div, b, a)
|
||||
else:
|
||||
fn_result = self._try_fn(fn, a, b)
|
||||
|
||||
if isinstance(m_result, Exception):
|
||||
self.assertTrue(isinstance(fn_result, Exception))
|
||||
elif fn is torch.div or a.is_floating_point():
|
||||
self.assertEqual(m_result, fn_result)
|
||||
else:
|
||||
# Skip when fn is not torch.div and a is integral because
|
||||
# historic_div_scalar_int performs floored division
|
||||
pass
|
||||
|
||||
if isinstance(b, float):
|
||||
_helper(v3_module_float, historic_div_scalar_float_reciprocal)
|
||||
_helper(current_module_float, torch.div)
|
||||
else:
|
||||
_helper(v3_module_int, historic_div_scalar_int_reciprocal)
|
||||
_helper(current_module_int, torch.div)
|
||||
|
||||
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||
def test_versioned_div_scalar_inplace(self):
|
||||
def historic_div_scalar_float_inplace(self, other: float):
|
||||
return self.true_divide_(other)
|
||||
|
||||
def historic_div_scalar_int_inplace(self, other: int):
|
||||
if self.is_floating_point():
|
||||
return self.true_divide_(other)
|
||||
|
||||
return self.divide_(other, rounding_mode='trunc')
|
||||
|
||||
class MyModuleFloat(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModuleFloat, self).__init__()
|
||||
|
||||
def forward(self, a, b: float):
|
||||
a /= b
|
||||
return a
|
||||
|
||||
class MyModuleInt(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModuleInt, self).__init__()
|
||||
|
||||
def forward(self, a, b: int):
|
||||
a /= b
|
||||
return a
|
||||
|
||||
try:
|
||||
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt")
|
||||
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
for m in (v3_module_float, v3_module_int):
|
||||
self._verify_count("aten::div_", m, 2) # true_divide and divide alias to div
|
||||
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
|
||||
|
||||
current_module_float = self._save_load_module(MyModuleFloat)
|
||||
current_module_int = self._save_load_module(MyModuleInt)
|
||||
|
||||
for m in (current_module_float, current_module_int):
|
||||
self._verify_count("aten::div", m, 1)
|
||||
|
||||
for m in (current_module_float, current_module_int):
|
||||
self._verify_count("aten::div", m, 1)
|
||||
|
||||
vals = (2., 3., 2, 3)
|
||||
for val_a, val_b in product(vals, vals):
|
||||
a = torch.tensor((val_a,))
|
||||
b = val_b
|
||||
|
||||
def _helper(m, fn):
|
||||
m_result = self._try_fn(m, a, b)
|
||||
fn_result = self._try_fn(fn, a, b)
|
||||
|
||||
if isinstance(m_result, Exception):
|
||||
self.assertTrue(fn_result, Exception)
|
||||
else:
|
||||
self.assertEqual(m_result, fn_result)
|
||||
|
||||
if isinstance(b, float):
|
||||
_helper(v3_module_float, historic_div_scalar_float_inplace)
|
||||
_helper(current_module_float, torch.Tensor.div_)
|
||||
else:
|
||||
_helper(v3_module_int, historic_div_scalar_int_inplace)
|
||||
_helper(current_module_int, torch.Tensor.div_)
|
||||
|
||||
# NOTE: Scalar division was already true division in op version 3,
|
||||
# so this test verifies the behavior is unchanged.
|
||||
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||
def test_versioned_div_scalar_scalar(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
|
||||
def forward(self, a: float, b: int, c: float, d: int):
|
||||
result_0 = a / b
|
||||
result_1 = a / c
|
||||
result_2 = b / c
|
||||
result_3 = b / d
|
||||
return (result_0, result_1, result_2, result_3)
|
||||
|
||||
try:
|
||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::div", v3_module, 4)
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
self._verify_count("aten::div", current_module, 4)
|
||||
|
||||
def _helper(m, fn):
|
||||
vals = (5., 3, 2., 7)
|
||||
m_result = m(*vals)
|
||||
fn_result = fn(*vals)
|
||||
for mr, hr in zip(m_result, fn_result):
|
||||
self.assertEqual(mr, hr)
|
||||
|
||||
_helper(v3_module, current_module)
|
||||
|
||||
# NOTE: the JIT was incapable of handling boolean fill values when
|
||||
# PyTorch produced file format versions 0-4
|
||||
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||
def test_versioned_full_integer_value(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
|
||||
def forward(self, int_fill: int):
|
||||
size = torch.Size(2, 2)
|
||||
a = torch.full(size, int_fill)
|
||||
b = torch.full(size, 1)
|
||||
return (a, b)
|
||||
|
||||
try:
|
||||
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::full", v4_module, 2)
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
self._verify_count("aten::full", current_module, 2)
|
||||
|
||||
# Verifies historic integer type inference is float
|
||||
# NOTE: only verifies floating point, not exact dtype, due to
|
||||
# https://github.com/pytorch/pytorch/issues/40470
|
||||
results = v4_module(2)
|
||||
for result in results:
|
||||
self.assertTrue(result.is_floating_point())
|
||||
|
||||
# Verifies values are correct
|
||||
a, b = results
|
||||
self.assertTrue((a == 2.).all())
|
||||
self.assertTrue((b == 1.).all())
|
||||
|
||||
# Tests that torch.full behavior which is the same from prior versions
|
||||
# to version 5 is preserved.
|
||||
# NOTE: while torch.full in eager PyTorch accepts a requires_grad argument,
|
||||
# it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363)
|
||||
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
|
||||
def test_versioned_full_preserved(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
|
||||
def forward(self, float_fill: float):
|
||||
size = (2, 2)
|
||||
a = torch.full(size, 1.)
|
||||
b = torch.full(size, float_fill)
|
||||
c = torch.full(size, float_fill, dtype=torch.long)
|
||||
|
||||
out = torch.empty(size, dtype=torch.long)
|
||||
d = torch.full(size, float_fill, out=out)
|
||||
|
||||
e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None,
|
||||
layout=torch.strided, device='cpu')
|
||||
return (a, b, c, d, e)
|
||||
|
||||
try:
|
||||
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::full", v4_module, 5)
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
self._verify_count("aten::full", current_module, 5)
|
||||
|
||||
self.assertEqual(v4_module(2.), current_module(2.))
|
||||
|
|
@ -24,23 +24,6 @@ if __name__ == "__main__":
|
|||
)
|
||||
|
||||
class TestSaveLoad(JitTestCase):
|
||||
|
||||
def test_versioned_symbols_reserialization(self):
|
||||
"""
|
||||
Tests that loading and saving serialized Torchscript with a versioned
|
||||
symbol won't persist the original function and will inline the
|
||||
versioned builtin.
|
||||
"""
|
||||
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(module_v2, buffer)
|
||||
buffer.seek(0)
|
||||
module_reserialized = torch.jit.load(buffer)
|
||||
|
||||
subcmul_nodes = sum("subcmul" in n.kind() for
|
||||
n in module_reserialized.graph.nodes())
|
||||
self.assertEqual(subcmul_nodes, 0)
|
||||
|
||||
def test_different_modules(self):
|
||||
"""
|
||||
Exercise the situation where we have the same qualified name
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
from itertools import product as product
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import hypothesis.strategies as st
|
||||
from hypothesis import example, settings, given
|
||||
|
|
@ -24,55 +23,6 @@ if __name__ == "__main__":
|
|||
)
|
||||
|
||||
class TestSaveLoadForOpVersion(JitTestCase):
|
||||
def test_versioned_symbols(self):
|
||||
"""
|
||||
Tests Torchscript symbol versioning. See note [Versioned Symbols].
|
||||
This test uses an undocumented, test-only function
|
||||
torch._test_serialization_subcmul.
|
||||
|
||||
This function is implemented as (a - alpha * b) with a default value
|
||||
of 1 for alpha. In file format version 2, however, it was implemented
|
||||
as (b - alpha * a) with a default value of 2 for alpha.
|
||||
This test verifies a module seralized with file format version 2
|
||||
exhibits the old behavior, and that the same module newly serialized
|
||||
exhibits the current behavior.
|
||||
"""
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
|
||||
def forward(self, a, b, alpha: float):
|
||||
no_alpha = torch._test_serialization_subcmul(a, b)
|
||||
with_alpha = torch._test_serialization_subcmul(a, b, alpha)
|
||||
return no_alpha, with_alpha
|
||||
|
||||
def historic_subcmul(a, b, alpha=2):
|
||||
return b - alpha * a
|
||||
|
||||
def current_subcmul(a, b, alpha=1):
|
||||
return a - alpha * b
|
||||
|
||||
# Loads and verifies the historic behavior of the module
|
||||
# that was serialized with version 2
|
||||
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
|
||||
a = torch.randn((5,))
|
||||
b = torch.randn((5,))
|
||||
alpha = random.random()
|
||||
args = (a, b, alpha)
|
||||
no_alpha_v2, with_alpha_v2 = module_v2(*args)
|
||||
self.assertEqual(no_alpha_v2, historic_subcmul(a, b))
|
||||
self.assertEqual(with_alpha_v2, historic_subcmul(*args))
|
||||
|
||||
# Scripts, saves, loads and verifies the current behavior of the module
|
||||
scripted_module = torch.jit.script(MyModule())
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(scripted_module, buffer)
|
||||
buffer.seek(0)
|
||||
module_current = torch.jit.load(buffer)
|
||||
no_alpha_current, with_alpha_current = module_current(*args)
|
||||
self.assertEqual(no_alpha_current, current_subcmul(a, b))
|
||||
self.assertEqual(with_alpha_current, current_subcmul(*args))
|
||||
|
||||
# Helper that returns the module after saving and loading
|
||||
def _save_load_module(self, m):
|
||||
scripted_module = torch.jit.script(m())
|
||||
|
|
@ -107,7 +57,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
|
||||
to call either aten::true_divide(_), if an input is a float type,
|
||||
or truncated aten::divide(_) otherwise.
|
||||
|
||||
NOTE: currently compares against current div behavior, too, since
|
||||
div behavior has not yet been updated.
|
||||
"""
|
||||
|
|
@ -137,18 +86,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
|
||||
# Loads historic module
|
||||
try:
|
||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt")
|
||||
v3_mobile_module = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::div", v3_module, 6) # true_divide and divide alias to div
|
||||
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
current_mobile_module = self._save_load_mobile_module(MyModule)
|
||||
self._verify_count("aten::div", current_module, 3)
|
||||
|
||||
for val_a, val_b in product(sample_input, sample_input):
|
||||
a = torch.tensor((val_a,))
|
||||
|
|
@ -164,9 +107,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
for result in m_results:
|
||||
self.assertEqual(result, fn_result)
|
||||
|
||||
_helper(v3_module, historic_div)
|
||||
_helper(v3_mobile_module, historic_div)
|
||||
_helper(current_module, torch.div)
|
||||
_helper(current_mobile_module, torch.div)
|
||||
|
||||
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
||||
|
|
@ -189,18 +130,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
return a
|
||||
|
||||
try:
|
||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt")
|
||||
v3_mobile_module = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::div", v3_module, 2) # true_divide and divide both alias to div
|
||||
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
current_mobile_module = self._save_load_mobile_module(MyModule)
|
||||
self._verify_count("aten::div", current_module, 1)
|
||||
|
||||
for val_a, val_b in product(sample_input, sample_input):
|
||||
a = torch.tensor((val_a,))
|
||||
|
|
@ -215,12 +150,10 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
self.assertEqual(m_result, fn_result)
|
||||
self.assertEqual(m_result, a)
|
||||
|
||||
_helper(v3_module, historic_div_)
|
||||
_helper(v3_mobile_module, historic_div_)
|
||||
|
||||
# Recreates a since it was modified in place
|
||||
a = torch.tensor((val_a,))
|
||||
_helper(current_module, torch.Tensor.div_)
|
||||
_helper(current_mobile_module, torch.Tensor.div_)
|
||||
|
||||
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
||||
|
|
@ -242,18 +175,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
return a.div(b, out=out)
|
||||
|
||||
try:
|
||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt")
|
||||
v3_mobile_module = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::div", v3_module, 2) # true_divide and divide alias to div
|
||||
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
current_mobile_module = self._save_load_mobile_module(MyModule)
|
||||
self._verify_count("aten::div", current_module, 1)
|
||||
|
||||
for val_a, val_b in product(sample_input, sample_input):
|
||||
a = torch.tensor((val_a,))
|
||||
|
|
@ -274,8 +201,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
self.assertEqual(m_result, fn_result)
|
||||
self.assertEqual(m_result, out)
|
||||
|
||||
_helper(v3_module, historic_div_out)
|
||||
_helper(current_module, torch.div)
|
||||
_helper(v3_mobile_module, historic_div_out)
|
||||
_helper(current_mobile_module, torch.div)
|
||||
|
||||
|
|
@ -308,8 +233,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
return a / b
|
||||
|
||||
try:
|
||||
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt")
|
||||
v3_module_int = torch.jit.load(pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v3.pt")
|
||||
v3_mobile_module_float = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl")
|
||||
v3_mobile_module_int = _load_for_lite_interpreter(
|
||||
|
|
@ -321,14 +244,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
self._verify_count("aten::div", m, 2) # true_divide and divide alias to div
|
||||
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
|
||||
|
||||
current_module_float = self._save_load_module(MyModuleFloat)
|
||||
current_module_int = self._save_load_module(MyModuleInt)
|
||||
current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
|
||||
current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)
|
||||
|
||||
for m in (current_module_float, current_module_int):
|
||||
self._verify_count("aten::div", m, 1)
|
||||
|
||||
for val_a, val_b in product(sample_input, sample_input):
|
||||
a = torch.tensor((val_a,))
|
||||
b = val_b
|
||||
|
|
@ -343,13 +261,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
self.assertEqual(m_result, fn_result)
|
||||
|
||||
if isinstance(b, float):
|
||||
_helper(v3_module_float, historic_div_scalar_float)
|
||||
_helper(current_module_float, torch.div)
|
||||
_helper(v3_mobile_module_float, current_mobile_module_float)
|
||||
_helper(current_mobile_module_float, torch.div)
|
||||
else:
|
||||
_helper(v3_module_int, historic_div_scalar_int)
|
||||
_helper(current_module_int, torch.div)
|
||||
_helper(v3_mobile_module_int, historic_div_scalar_int)
|
||||
_helper(current_mobile_module_int, torch.div)
|
||||
|
||||
|
|
@ -382,8 +296,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
return b / a
|
||||
|
||||
try:
|
||||
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt")
|
||||
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt")
|
||||
v3_mobile_module_float = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl")
|
||||
v3_mobile_module_int = _load_for_lite_interpreter(
|
||||
|
|
@ -391,17 +303,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
# NOTE: number / tensor is rewritten to torch.reciprocal(a) * b
|
||||
# so true_divide and floor_divide do not appear in their graphs
|
||||
for m in (v3_module_float, v3_module_int):
|
||||
self._verify_no("aten::div", m)
|
||||
self._verify_no("aten::true_divide", m)
|
||||
self._verify_no("aten::floor_divide", m)
|
||||
self._verify_count("aten::reciprocal", m, 1)
|
||||
|
||||
current_module_float = self._save_load_module(MyModuleFloat)
|
||||
current_module_int = self._save_load_module(MyModuleInt)
|
||||
|
||||
current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
|
||||
current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)
|
||||
|
||||
|
|
@ -428,13 +329,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
pass
|
||||
|
||||
if isinstance(b, float):
|
||||
_helper(v3_module_float, historic_div_scalar_float_reciprocal)
|
||||
_helper(current_module_float, torch.div)
|
||||
_helper(v3_mobile_module_float, current_mobile_module_float)
|
||||
_helper(current_mobile_module_float, torch.div)
|
||||
else:
|
||||
_helper(v3_module_int, historic_div_scalar_int_reciprocal)
|
||||
_helper(current_module_int, torch.div)
|
||||
_helper(v3_mobile_module_int, current_mobile_module_int)
|
||||
_helper(current_mobile_module_int, torch.div)
|
||||
|
||||
|
|
@ -470,9 +367,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
return a
|
||||
|
||||
try:
|
||||
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt")
|
||||
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt")
|
||||
|
||||
v3_mobile_module_float = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl")
|
||||
v3_mobile_module_int = _load_for_lite_interpreter(
|
||||
|
|
@ -480,22 +374,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
for m in (v3_module_float, v3_module_int):
|
||||
self._verify_count("aten::div_", m, 2) # true_divide and divide alias to div
|
||||
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
|
||||
|
||||
current_module_float = self._save_load_module(MyModuleFloat)
|
||||
current_module_int = self._save_load_module(MyModuleInt)
|
||||
|
||||
current_mobile_module_float = self._save_load_module(MyModuleFloat)
|
||||
current_mobile_module_int = self._save_load_module(MyModuleInt)
|
||||
|
||||
for m in (current_module_float, current_module_int):
|
||||
self._verify_count("aten::div", m, 1)
|
||||
|
||||
for m in (current_module_float, current_module_int):
|
||||
self._verify_count("aten::div", m, 1)
|
||||
|
||||
for val_a, val_b in product(sample_input, sample_input):
|
||||
a = torch.tensor((val_a,))
|
||||
b = val_b
|
||||
|
|
@ -510,12 +391,8 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
self.assertEqual(m_result, fn_result)
|
||||
|
||||
if isinstance(b, float):
|
||||
_helper(v3_module_float, historic_div_scalar_float_inplace)
|
||||
_helper(current_module_float, torch.Tensor.div_)
|
||||
_helper(current_mobile_module_float, torch.Tensor.div_)
|
||||
else:
|
||||
_helper(v3_module_int, historic_div_scalar_int_inplace)
|
||||
_helper(current_module_int, torch.Tensor.div_)
|
||||
_helper(current_mobile_module_int, torch.Tensor.div_)
|
||||
|
||||
# NOTE: Scalar division was already true division in op version 3,
|
||||
|
|
@ -533,17 +410,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
return (result_0, result_1, result_2, result_3)
|
||||
|
||||
try:
|
||||
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt")
|
||||
v3_mobile_module = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::div", v3_module, 4)
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
current_mobile_module = self._save_load_mobile_module(MyModule)
|
||||
self._verify_count("aten::div", current_module, 4)
|
||||
|
||||
def _helper(m, fn):
|
||||
vals = (5., 3, 2., 7)
|
||||
|
|
@ -552,74 +424,4 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
|||
for mr, hr in zip(m_result, fn_result):
|
||||
self.assertEqual(mr, hr)
|
||||
|
||||
_helper(v3_module, current_module)
|
||||
_helper(v3_mobile_module, current_mobile_module)
|
||||
|
||||
# NOTE: the JIT was incapable of handling boolean fill values when
|
||||
# PyTorch produced file format versions 0-4
|
||||
def test_versioned_full_integer_value(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
|
||||
def forward(self, int_fill: int):
|
||||
size = torch.Size(2, 2)
|
||||
a = torch.full(size, int_fill)
|
||||
b = torch.full(size, 1)
|
||||
return (a, b)
|
||||
|
||||
try:
|
||||
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::full", v4_module, 2)
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
self._verify_count("aten::full", current_module, 2)
|
||||
|
||||
# Verifies historic integer type inference is float
|
||||
# NOTE: only verifies floating point, not exact dtype, due to
|
||||
# https://github.com/pytorch/pytorch/issues/40470
|
||||
results = v4_module(2)
|
||||
for result in results:
|
||||
self.assertTrue(result.is_floating_point())
|
||||
|
||||
# Verifies values are correct
|
||||
a, b = results
|
||||
self.assertTrue((a == 2.).all())
|
||||
self.assertTrue((b == 1.).all())
|
||||
|
||||
# Tests that torch.full behavior which is the same from prior versions
|
||||
# to version 5 is preserved.
|
||||
# NOTE: while torch.full in eager PyTorch accepts a requires_grad argument,
|
||||
# it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363)
|
||||
def test_versioned_full_preserved(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
|
||||
def forward(self, float_fill: float):
|
||||
size = (2, 2)
|
||||
a = torch.full(size, 1.)
|
||||
b = torch.full(size, float_fill)
|
||||
c = torch.full(size, float_fill, dtype=torch.long)
|
||||
|
||||
out = torch.empty(size, dtype=torch.long)
|
||||
d = torch.full(size, float_fill, out=out)
|
||||
|
||||
e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None,
|
||||
layout=torch.strided, device='cpu')
|
||||
return (a, b, c, d, e)
|
||||
|
||||
try:
|
||||
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt")
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
self._verify_count("aten::full", v4_module, 5)
|
||||
|
||||
current_module = self._save_load_module(MyModule)
|
||||
self._verify_count("aten::full", current_module, 5)
|
||||
|
||||
self.assertEqual(v4_module(2.), current_module(2.))
|
||||
|
|
|
|||
|
|
@ -4,6 +4,11 @@ import io
|
|||
import os
|
||||
import sys
|
||||
import torch
|
||||
import unittest
|
||||
import zipfile
|
||||
from torch.testing import FileCheck
|
||||
from torch._C import _is_upgraders_enabled
|
||||
from typing import Union
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
|
@ -16,6 +21,15 @@ if __name__ == '__main__':
|
|||
"instead.")
|
||||
|
||||
class TestUpgraders(JitTestCase):
|
||||
def _load_model_version(self, loaded_model):
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(loaded_model, buffer)
|
||||
buffer.seek(0)
|
||||
zipped_model = zipfile.ZipFile(buffer)
|
||||
version = int(zipped_model.read('archive/version').decode("utf-8"))
|
||||
return version
|
||||
|
||||
# TODO (tugsuu) We should ideally be generating this test cases.
|
||||
def test_populated_upgrader_graph(self):
|
||||
@torch.jit.script
|
||||
def f():
|
||||
|
|
@ -67,7 +81,7 @@ class TestUpgraders(JitTestCase):
|
|||
# upgrader map should have populated now
|
||||
upgraders_size = torch._C._get_upgraders_map_size()
|
||||
|
||||
test_map = {"a": "b", "c": "d"}
|
||||
test_map = {"a": str(torch._C.Graph()), "c": str(torch._C.Graph())}
|
||||
torch._C._test_only_populate_upgraders(test_map)
|
||||
upgraders_size_after_test = torch._C._get_upgraders_map_size()
|
||||
self.assertEqual(upgraders_size_after_test - upgraders_size, 2)
|
||||
|
|
@ -81,3 +95,117 @@ class TestUpgraders(JitTestCase):
|
|||
upgraders_dump_after_remove_test = torch._C._dump_upgraders_map()
|
||||
self.assertTrue("a" not in upgraders_dump_after_remove_test)
|
||||
self.assertTrue("c" not in upgraders_dump_after_remove_test)
|
||||
|
||||
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||
def test_aten_div_tensor_at_3(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt"
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
# there are 3 aten::div in this model
|
||||
# And the upgrader for aten::div uses two
|
||||
# div's because of if/else branch
|
||||
FileCheck().check("prim::If").run(loaded_model.graph)
|
||||
FileCheck().check_count("aten::div", 6).run(loaded_model.graph)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(loaded_model, buffer)
|
||||
buffer.seek(0)
|
||||
version = self._load_model_version(loaded_model)
|
||||
self.assertTrue(version == 4)
|
||||
loaded_model_twice = torch.jit.load(buffer)
|
||||
# we check by its code because graph variable names
|
||||
# can be different every time
|
||||
self.assertEqual(loaded_model.code, loaded_model_twice.code)
|
||||
|
||||
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||
def test_aten_test_serialization(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
|
||||
|
||||
# add test version entry to the version map
|
||||
upgrader_bumped_version = 3
|
||||
upgrader_name = "_test_serialization_subcmul_0_2"
|
||||
upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
|
||||
dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema)
|
||||
|
||||
torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry)
|
||||
|
||||
# add test upgrader in the upgraders map
|
||||
@torch.jit.script
|
||||
def _test_serialization_subcmul_0_2(self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2) -> torch.Tensor:
|
||||
return other - (self * alpha)
|
||||
torch._C._test_only_populate_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
|
||||
|
||||
# test if the server is able to find the test upgraders and apply to IR
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
FileCheck().check_count("aten::mul", 2).run(loaded_model.graph)
|
||||
FileCheck().check_count("aten::sub", 2).run(loaded_model.graph)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(loaded_model, buffer)
|
||||
buffer.seek(0)
|
||||
version = self._load_model_version(loaded_model)
|
||||
self.assertTrue(version == 3)
|
||||
loaded_model_twice = torch.jit.load(buffer)
|
||||
# we check by its' code because graph variable names
|
||||
# can be different every time
|
||||
self.assertEqual(loaded_model.code, loaded_model_twice.code)
|
||||
torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul")
|
||||
torch._C._test_only_remove_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
|
||||
|
||||
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||
def test_aten_div_scalar_at_3(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt"
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
FileCheck().check("prim::If").run(loaded_model.graph)
|
||||
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(loaded_model, buffer)
|
||||
buffer.seek(0)
|
||||
version = self._load_model_version(loaded_model)
|
||||
self.assertTrue(version == 4)
|
||||
loaded_model_twice = torch.jit.load(buffer)
|
||||
|
||||
self.assertEqual(loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
|
||||
loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0))
|
||||
|
||||
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||
def test_aten_div_tensor_out_at_3(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt"
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
FileCheck().check("prim::If").run(loaded_model.graph)
|
||||
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(loaded_model, buffer)
|
||||
buffer.seek(0)
|
||||
version = self._load_model_version(loaded_model)
|
||||
self.assertTrue(version == 4)
|
||||
loaded_model_twice = torch.jit.load(buffer)
|
||||
# we check by its' code because graph variable names
|
||||
# can be different every time
|
||||
self.assertEqual(loaded_model.code, loaded_model_twice.code)
|
||||
|
||||
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||
def test_aten_full_at_4(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt"
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
FileCheck().check_count("aten::Float", 1).run(loaded_model.graph)
|
||||
FileCheck().check_count("aten::full", 2).run(loaded_model.graph)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(loaded_model, buffer)
|
||||
buffer.seek(0)
|
||||
version = self._load_model_version(loaded_model)
|
||||
self.assertTrue(version == 5)
|
||||
loaded_model_twice = torch.jit.load(buffer)
|
||||
# we check by its' code because graph variable names
|
||||
# can be different every time
|
||||
self.assertEqual(loaded_model.code, loaded_model_twice.code)
|
||||
|
||||
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||
def test_aten_full_out_at_4(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt"
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
|
||||
version = self._load_model_version(loaded_model)
|
||||
self.assertTrue(version == 5)
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ from jit.test_attr import TestGetDefaultAttr # noqa: F401
|
|||
from jit.test_aten_pow import TestAtenPow # noqa: F401
|
||||
from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401
|
||||
from jit.test_union import TestUnion # noqa: F401
|
||||
from jit.test_legacy_upgraders import TestLegacyUpgraders # noqa: F401
|
||||
from jit.test_models import MnistNet
|
||||
from jit.test_batch_mm import TestBatchMM # noqa: F401
|
||||
from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401
|
||||
|
|
|
|||
|
|
@ -112,6 +112,7 @@ core_sources_common = [
|
|||
"torch/csrc/jit/mobile/type_parser.cpp",
|
||||
"torch/csrc/jit/mobile/runtime_compatibility.cpp",
|
||||
"torch/csrc/jit/operator_upgraders/version_map.cpp",
|
||||
"torch/csrc/jit/operator_upgraders/upgraders_guard.cpp",
|
||||
"torch/csrc/jit/runtime/instruction.cpp",
|
||||
"torch/csrc/jit/runtime/jit_exception.cpp",
|
||||
"torch/csrc/jit/runtime/operator.cpp",
|
||||
|
|
@ -121,7 +122,6 @@ core_sources_common = [
|
|||
"torch/csrc/jit/runtime/vararg_functions.cpp",
|
||||
"torch/csrc/jit/mobile/promoted_prim_ops.cpp",
|
||||
"torch/csrc/jit/mobile/prim_ops_registery.cpp",
|
||||
"torch/csrc/jit/operator_upgraders/upgraders.cpp",
|
||||
"torch/csrc/profiler/util.cpp",
|
||||
]
|
||||
|
||||
|
|
@ -209,6 +209,8 @@ core_sources_full_mobile_no_backend_interface = [
|
|||
"torch/csrc/jit/mobile/nnc/context.cpp",
|
||||
"torch/csrc/jit/mobile/nnc/registry.cpp",
|
||||
"torch/csrc/jit/operator_upgraders/utils.cpp",
|
||||
"torch/csrc/jit/operator_upgraders/upgraders_entry.cpp",
|
||||
"torch/csrc/jit/operator_upgraders/upgraders.cpp",
|
||||
"torch/csrc/jit/passes/annotate_warns.cpp",
|
||||
"torch/csrc/jit/passes/bailout_graph.cpp",
|
||||
"torch/csrc/jit/passes/batch_mm.cpp",
|
||||
|
|
|
|||
|
|
@ -293,7 +293,7 @@ def _create_function_from_trace(
|
|||
def _jit_is_script_object(obj: Any) -> _bool: ...
|
||||
def _last_executed_optimized_graph() -> Graph: ...
|
||||
def parse_type_comment(comment: str) -> Decl: ...
|
||||
def _populate_upgraders_map(content: Dict[str, str]) -> None: ...
|
||||
def _is_upgraders_enabled() -> _bool: ...
|
||||
def _get_upgraders_map_size() -> _int: ...
|
||||
def _dump_upgraders_map() -> Dict[str, str]: ...
|
||||
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#include <torch/csrc/jit/frontend/builtin_functions.h>
|
||||
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/api/include/torch/jit.h>
|
||||
#include <torch/csrc/jit/frontend/code_template.h>
|
||||
#include <torch/csrc/jit/frontend/resolver.h>
|
||||
|
|
@ -85,6 +86,7 @@ def __contains__(self: str, key: str):
|
|||
return self.find(key, 0, len(self)) != -1
|
||||
)SCRIPT";
|
||||
|
||||
#if !ENABLE_UPGRADERS
|
||||
// Implementations of historic symbol behaviors are defined here
|
||||
// See note [Versioned Symbols]
|
||||
|
||||
|
|
@ -158,7 +160,6 @@ def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
|
|||
pin_memory:Optional[bool]=None) -> Tensor:
|
||||
if dtype is None:
|
||||
fill_value = float(fill_value)
|
||||
|
||||
return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
|
||||
)SCRIPT";
|
||||
|
||||
|
|
@ -168,6 +169,7 @@ auto full_out = R"SCRIPT(
|
|||
def full_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
|
||||
return torch.full(size, fill_value, out=out)
|
||||
)SCRIPT";
|
||||
#endif
|
||||
|
||||
struct BuiltinFunctionRegistry {
|
||||
const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
|
||||
|
|
@ -237,6 +239,7 @@ struct BuiltinFunctionRegistry {
|
|||
loadSource(aten_ops, "aten");
|
||||
loadSource(aten_ops_additional, "aten");
|
||||
|
||||
#if !ENABLE_UPGRADERS
|
||||
// Loads functions implementing historic behavior, see note [Versioned
|
||||
// Symbols]
|
||||
// Note: these functions go into the "upgraders" namespace
|
||||
|
|
@ -249,6 +252,7 @@ struct BuiltinFunctionRegistry {
|
|||
loadSource(div__scalar, "upgraders");
|
||||
loadSource(full, "upgraders");
|
||||
loadSource(full_out, "upgraders");
|
||||
#endif
|
||||
|
||||
// These are under `prim` instead of `aten` since they exist to bind certain
|
||||
// tensor property getters to correpsonding methods
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
|
||||
#include <torch/csrc/jit/frontend/convert_to_ssa.h>
|
||||
|
|
@ -22,6 +23,7 @@
|
|||
#include <torch/csrc/jit/passes/lift_closures.h>
|
||||
#include <torch/csrc/jit/passes/lower_tuples.h>
|
||||
#include <torch/csrc/jit/passes/normalize_ops.h>
|
||||
#include <torch/csrc/jit/passes/replacement_of_old_operators.h>
|
||||
#include <torch/csrc/jit/runtime/interpreter.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
|
||||
|
|
@ -656,6 +658,13 @@ struct to_ir {
|
|||
}
|
||||
method.setSchema(emitDef(def, self, graph->block()));
|
||||
|
||||
#if ENABLE_UPGRADERS
|
||||
// At this point, we might have received a graph that is compiled with
|
||||
// old operator schemas that might not exist in the system anymore.
|
||||
// Therefore, we replace such ops with its' valid upgrader.
|
||||
ReplaceOldOperatorsWithUpgraders(graph);
|
||||
#endif
|
||||
|
||||
// NB ORDERING: SSA conversion has to occur before
|
||||
// lifting of closures and forks, this way closures are converted
|
||||
// to SSA while part of their original graph, and closures are ready to
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/jit/frontend/builtin_functions.h>
|
||||
#include <torch/csrc/jit/frontend/error_report.h>
|
||||
#include <torch/csrc/jit/frontend/function_schema_parser.h>
|
||||
|
|
@ -469,7 +470,7 @@ static c10::optional<MatchedSchema> tryMatchSchema(
|
|||
}
|
||||
|
||||
// construct the full name of the schema for easier look up
|
||||
auto schema_name = schema.operator_name().name + "." + schema.overload_name();
|
||||
auto schema_name = getFullSchemaName(schema);
|
||||
|
||||
return MatchedSchema{
|
||||
std::move(positional_inputs),
|
||||
|
|
@ -619,6 +620,13 @@ static Value* emitBuiltinNode(
|
|||
return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
|
||||
}
|
||||
|
||||
std::string getFullSchemaName(const ::c10::FunctionSchema& schema) {
|
||||
if (schema.overload_name() != "") {
|
||||
return schema.operator_name().name + "." + schema.overload_name();
|
||||
}
|
||||
return schema.operator_name().name;
|
||||
}
|
||||
|
||||
// Search for operators matching the provided symbol name and input types.
|
||||
// If one is found, emit a node to the graph for that operator.
|
||||
Value* emitBuiltinCall(
|
||||
|
|
@ -631,8 +639,12 @@ Value* emitBuiltinCall(
|
|||
const auto& variants = getAllOperatorsFor(name);
|
||||
const auto& builtin_functions = getAllBuiltinFunctionsFor(name);
|
||||
|
||||
#if ENABLE_UPGRADERS
|
||||
// first let's set the graph's version
|
||||
auto graph_version = graph.get_op_version();
|
||||
#else
|
||||
c10::optional<size_t> graph_version = c10::nullopt;
|
||||
#endif
|
||||
|
||||
std::stringstream failure_messages;
|
||||
std::vector<const FunctionSchema*> schemas;
|
||||
|
|
@ -643,9 +655,7 @@ Value* emitBuiltinCall(
|
|||
schemas.reserve(variants.size());
|
||||
for (const std::shared_ptr<Operator>& op : variants) {
|
||||
bool found_upgrader = false;
|
||||
auto op_overload_name = op->schema().overload_name();
|
||||
auto op_name = op->schema().operator_name().name +
|
||||
((op_overload_name != "") ? "." + op_overload_name : "");
|
||||
auto op_name = getFullSchemaName(op->schema());
|
||||
if (graph_version.has_value()) {
|
||||
auto version_entry = get_operator_version_map().find(op_name);
|
||||
if (version_entry != get_operator_version_map().end()) {
|
||||
|
|
|
|||
|
|
@ -41,6 +41,8 @@ TORCH_API bool convertibleToList(
|
|||
const TypePtr& type,
|
||||
const TypePtr& list_type_);
|
||||
|
||||
TORCH_API std::string getFullSchemaName(const ::c10::FunctionSchema& schema);
|
||||
|
||||
TORCH_API Value* emitBuiltinCall(
|
||||
const SourceRange& loc,
|
||||
Graph& graph,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include <utility>
|
||||
|
||||
#include <ATen/core/symbol.h>
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/frontend/error_report.h>
|
||||
#include <torch/csrc/jit/frontend/schema_matching.h>
|
||||
|
|
@ -319,12 +320,14 @@ struct TORCH_API BuiltinModule : public SugaredValue {
|
|||
}
|
||||
|
||||
auto sym = Symbol::fromQualString(name + "::" + field);
|
||||
#if !ENABLE_UPGRADERS
|
||||
if (version.has_value()) {
|
||||
// Possibly replaces symbol with another that implements its
|
||||
// historic behavior.
|
||||
// See note [Versioned Symbols]
|
||||
sym = get_symbol_for_version(sym, *version);
|
||||
}
|
||||
#endif
|
||||
return std::make_shared<BuiltinFunction>(sym, c10::nullopt);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
#include <torch/csrc/jit/frontend/versioned_symbols.h>
|
||||
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/api/include/torch/jit.h>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
#if !ENABLE_UPGRADERS
|
||||
// Note [Versioned Symbols]
|
||||
// When the schema or behavior of a symbol changes, serialized Torchscript
|
||||
// programs using that symbol are likely to break. To prevent those breaks,
|
||||
|
|
@ -105,6 +106,7 @@ uint64_t get_min_version_for_kind(const NodeKind& kind) {
|
|||
|
||||
return it->second;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
|
||||
|
|
@ -7,7 +8,7 @@
|
|||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
#if !ENABLE_UPGRADERS
|
||||
// Maps the given symbol into an implementation of its behavior at the
|
||||
// given version.
|
||||
// See note [Versioned Symbols]
|
||||
|
|
@ -17,6 +18,6 @@ get_symbol_for_version(const Symbol name, const uint64_t version);
|
|||
// Maps the given kind to the minimum version that supports it.
|
||||
// See note [Dynamic Versions and torch.jit.save vs. torch.save]
|
||||
TORCH_API uint64_t get_min_version_for_kind(const NodeKind& kind);
|
||||
|
||||
#endif
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
|
@ -9,12 +12,13 @@ namespace jit {
|
|||
static UpgradersMap upgradersMap;
|
||||
|
||||
void UpgradersMap::set_content(
|
||||
std::unordered_map<std::string, std::string>&& content) {
|
||||
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
|
||||
// make sure we populate the map only once
|
||||
std::lock_guard<std::mutex> _(lock);
|
||||
if (isPopulated) {
|
||||
return;
|
||||
}
|
||||
|
||||
content_ = std::move(content);
|
||||
isPopulated = true;
|
||||
}
|
||||
|
|
@ -24,7 +28,12 @@ int UpgradersMap::count() {
|
|||
return content_.size();
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::string>& UpgradersMap::
|
||||
bool UpgradersMap::is_populated() {
|
||||
std::lock_guard<std::mutex> _(lock);
|
||||
return isPopulated;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::shared_ptr<Graph>>& UpgradersMap::
|
||||
get_content() {
|
||||
std::lock_guard<std::mutex> _(lock);
|
||||
return content_;
|
||||
|
|
@ -34,7 +43,9 @@ void UpgradersMap::test_only_set_content(
|
|||
const std::unordered_map<std::string, std::string>& content) {
|
||||
std::lock_guard<std::mutex> _(lock);
|
||||
for (const auto& entry : content) {
|
||||
content_.insert(entry);
|
||||
auto graph = std::make_shared<Graph>();
|
||||
torch::jit::parseIR(entry.second, graph.get());
|
||||
content_.insert(std::make_pair(entry.first, graph));
|
||||
}
|
||||
}
|
||||
void UpgradersMap::test_only_remove_content(
|
||||
|
|
@ -46,7 +57,7 @@ void UpgradersMap::test_only_remove_content(
|
|||
}
|
||||
|
||||
void populate_upgraders_map(
|
||||
std::unordered_map<std::string, std::string>&& content) {
|
||||
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
|
||||
upgradersMap.set_content(std::move(content));
|
||||
}
|
||||
|
||||
|
|
@ -54,7 +65,12 @@ int get_upgraders_map_size() {
|
|||
return upgradersMap.count();
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::string>& dump_upgraders_map() {
|
||||
bool is_upgraders_map_populated() {
|
||||
return upgradersMap.is_populated();
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::shared_ptr<Graph>>&
|
||||
dump_upgraders_map() {
|
||||
return upgradersMap.get_content();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
#include <c10/macros/Export.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
|
|
@ -10,9 +11,11 @@ namespace jit {
|
|||
|
||||
class UpgradersMap {
|
||||
public:
|
||||
void set_content(std::unordered_map<std::string, std::string>&& content);
|
||||
void set_content(
|
||||
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content);
|
||||
int count();
|
||||
const std::unordered_map<std::string, std::string>& get_content();
|
||||
const std::unordered_map<std::string, std::shared_ptr<Graph>>& get_content();
|
||||
bool is_populated();
|
||||
// THESE METHODS ARE ONLY USED FOR TESTING PURPOSES
|
||||
void test_only_set_content(
|
||||
const std::unordered_map<std::string, std::string>& content);
|
||||
|
|
@ -20,17 +23,19 @@ class UpgradersMap {
|
|||
const std::unordered_map<std::string, std::string>& content);
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::string> content_;
|
||||
std::unordered_map<std::string, std::shared_ptr<Graph>> content_;
|
||||
std::mutex lock;
|
||||
bool isPopulated = false;
|
||||
};
|
||||
|
||||
TORCH_API void populate_upgraders_map(
|
||||
std::unordered_map<std::string, std::string>&& content);
|
||||
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content);
|
||||
|
||||
TORCH_API int get_upgraders_map_size();
|
||||
|
||||
TORCH_API const std::unordered_map<std::string, std::string>&
|
||||
TORCH_API bool is_upgraders_map_populated();
|
||||
|
||||
TORCH_API const std::unordered_map<std::string, std::shared_ptr<Graph>>&
|
||||
dump_upgraders_map();
|
||||
|
||||
// THESE TWO METHODS BELOW ARE ONLY USED FOR TESTING
|
||||
|
|
|
|||
78
torch/csrc/jit/operator_upgraders/upgraders_entry.cpp
Normal file
78
torch/csrc/jit/operator_upgraders/upgraders_entry.cpp
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
|
||||
|
||||
#include <ATen/core/stack.h>
|
||||
#include <c10/macros/Export.h>
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
static std::unordered_map<std::string, std::string> kUpgradersEntryMap(
|
||||
{{"div_Tensor_0_3", R"SCRIPT(
|
||||
def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
|
||||
if (self.is_floating_point() or other.is_floating_point()):
|
||||
return self.true_divide(other)
|
||||
return self.divide(other, rounding_mode='trunc')
|
||||
)SCRIPT"},
|
||||
{"div_Scalar_0_3", R"SCRIPT(
|
||||
def div_Scalar_0_3(self: Tensor, other: number) -> Tensor:
|
||||
if (self.is_floating_point() or isinstance(other, float)):
|
||||
return self.true_divide(other)
|
||||
return self.divide(other, rounding_mode='trunc')
|
||||
)SCRIPT"},
|
||||
{"div_out_0_3", R"SCRIPT(
|
||||
def div_out_0_3(self: Tensor, other: Tensor, *, out: Tensor) -> Tensor:
|
||||
if (self.is_floating_point() or other.is_floating_point() or out.is_floating_point()):
|
||||
return self.true_divide(other, out=out)
|
||||
return self.divide(other, rounding_mode='trunc', out=out)
|
||||
)SCRIPT"},
|
||||
{"div__Tensor_0_3", R"SCRIPT(
|
||||
def div__Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
|
||||
if (self.is_floating_point() or other.is_floating_point()):
|
||||
return self.true_divide_(other)
|
||||
return self.divide_(other, rounding_mode='trunc')
|
||||
)SCRIPT"},
|
||||
{"div__Scalar_0_3", R"SCRIPT(
|
||||
def div__Scalar_0_3(self: Tensor, other: number) -> Tensor:
|
||||
if (self.is_floating_point() or isinstance(other, float)):
|
||||
return self.true_divide_(other)
|
||||
return self.divide_(other, rounding_mode='trunc')
|
||||
)SCRIPT"},
|
||||
{"full_0_4", R"SCRIPT(
|
||||
def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
|
||||
layout:Optional[int]=None, device:Optional[Device]=None,
|
||||
pin_memory:Optional[bool]=None) -> Tensor:
|
||||
if dtype is None:
|
||||
fill_value = float(fill_value)
|
||||
return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
|
||||
)SCRIPT"},
|
||||
{"full_out_0_4", R"SCRIPT(
|
||||
def full_out_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
|
||||
return torch.full(size, fill_value, out=out)
|
||||
)SCRIPT"}});
|
||||
|
||||
using UpgraderMap = std::unordered_map<std::string, std::shared_ptr<Graph>>;
|
||||
void populate_upgraders_graph_map() {
|
||||
if (!is_upgraders_map_populated()) {
|
||||
UpgraderMap populate_content;
|
||||
for (const auto& entry : kUpgradersEntryMap) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
cu->define(c10::nullopt, entry.second, nativeResolver(), nullptr);
|
||||
Function& jitFunc = cu->get_function(entry.first);
|
||||
GraphFunction& graphFunction = toGraphFunction(jitFunc);
|
||||
populate_content.insert(
|
||||
std::make_pair(entry.first, graphFunction.graph()));
|
||||
}
|
||||
|
||||
populate_upgraders_map(std::forward<UpgraderMap>(populate_content));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
14
torch/csrc/jit/operator_upgraders/upgraders_entry.h
Normal file
14
torch/csrc/jit/operator_upgraders/upgraders_entry.h
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
#pragma once
|
||||
#include <c10/macros/Export.h>
|
||||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
TORCH_API void populate_upgraders_graph_map();
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
12
torch/csrc/jit/operator_upgraders/upgraders_guard.cpp
Normal file
12
torch/csrc/jit/operator_upgraders/upgraders_guard.cpp
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/upgraders_guard.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
bool is_upgraders_enabled() {
|
||||
return ENABLE_UPGRADERS;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
10
torch/csrc/jit/operator_upgraders/upgraders_guard.h
Normal file
10
torch/csrc/jit/operator_upgraders/upgraders_guard.h
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
#pragma once
|
||||
#include <c10/macros/Export.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
TORCH_API bool is_upgraders_enabled();
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -31,11 +31,11 @@ static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersi
|
|||
{"aten::div_.Tensor",
|
||||
{{4,
|
||||
"div__Tensor_0_3",
|
||||
"aten::div_.Tensor(Tensor(a!), Tensor other) -> Tensor(a!)"}}},
|
||||
"aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"}}},
|
||||
{"aten::div_.Scalar",
|
||||
{{4,
|
||||
"div__Scalar_0_3",
|
||||
"aten::div_.Scalar(Tensor(a!), Tensor other) -> Tensor(a!)"}}},
|
||||
"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"}}},
|
||||
{"aten::full",
|
||||
{{5,
|
||||
"full_0_4",
|
||||
|
|
|
|||
|
|
@ -15,29 +15,6 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
static std::unordered_map<std::string, std::shared_ptr<Graph>> upgraderCache;
|
||||
|
||||
std::shared_ptr<Graph> getUpgraderGraph(const std::string& upgrader_name) {
|
||||
auto it = upgraderCache.find(upgrader_name);
|
||||
if (it != upgraderCache.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
upgrader_graph_entry != dump_upgraders_map().end(),
|
||||
"Corresponding upgrader graph for ",
|
||||
upgrader_name,
|
||||
" must exist.",
|
||||
" This upgrader"
|
||||
" might be deprecated.");
|
||||
|
||||
auto upgrader_graph = std::make_shared<Graph>();
|
||||
parseIR(upgrader_graph_entry->second, upgrader_graph.get());
|
||||
upgraderCache[upgrader_name] = upgrader_graph;
|
||||
return upgrader_graph;
|
||||
}
|
||||
|
||||
struct OldOpsReplacerWithUpgraders {
|
||||
OldOpsReplacerWithUpgraders(std::shared_ptr<Graph> graph)
|
||||
: graph_(std::move(graph)) {}
|
||||
|
|
@ -52,9 +29,7 @@ struct OldOpsReplacerWithUpgraders {
|
|||
Node* node = graph_it.next();
|
||||
while (node) {
|
||||
if (auto schema = node->maybeSchema()) {
|
||||
auto schema_name = schema->name() +
|
||||
(schema->overload_name() != "" ? "." + schema->overload_name()
|
||||
: "");
|
||||
auto schema_name = getFullSchemaName(*schema);
|
||||
// this implies there was a version bump because of this operator
|
||||
auto version_entry = get_operator_version_map().find(schema_name);
|
||||
if (version_entry != get_operator_version_map().end()) {
|
||||
|
|
@ -74,8 +49,16 @@ struct OldOpsReplacerWithUpgraders {
|
|||
}
|
||||
auto upgrader_entry_val = upgrader_entry.value();
|
||||
auto upgrader_name = upgrader_entry_val.upgrader_name;
|
||||
auto upgrader_graph = getUpgraderGraph(upgrader_name);
|
||||
auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
upgrader_graph_entry != dump_upgraders_map().end(),
|
||||
"Corresponding upgrader graph for ",
|
||||
upgrader_name,
|
||||
" must exist.",
|
||||
" This upgrader"
|
||||
" might be deprecated.");
|
||||
|
||||
auto upgrader_graph = upgrader_graph_entry->second;
|
||||
// inline the upgrader function body
|
||||
WithInsertPoint guard(node);
|
||||
auto new_outputs = insertGraph(
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
#include <torch/csrc/jit/mobile/model_compatibility.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/upgraders_guard.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/version_map.h>
|
||||
#include <torch/csrc/jit/python/module_python.h>
|
||||
#include <torch/csrc/jit/python/python_ivalue.h>
|
||||
|
|
@ -1731,7 +1732,8 @@ void initJitScriptBindings(PyObject* module) {
|
|||
return Decl(p.parseTypeComment());
|
||||
});
|
||||
|
||||
m.def("_populate_upgraders_map", &populate_upgraders_map);
|
||||
m.def("_is_upgraders_enabled", &is_upgraders_enabled);
|
||||
|
||||
m.def("_get_upgraders_map_size", &get_upgraders_map_size);
|
||||
m.def("_dump_upgraders_map", &dump_upgraders_map);
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#endif
|
||||
#include <torch/csrc/jit/frontend/script_type_parser.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
#include <torch/csrc/jit/serialization/import_read.h>
|
||||
#include <torch/csrc/jit/serialization/import_source.h>
|
||||
|
|
@ -20,6 +21,7 @@
|
|||
#include <caffe2/serialize/file_adapter.h>
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
#include <caffe2/serialize/istream_adapter.h>
|
||||
#include <caffe2/serialize/versions.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <fmt/format.h>
|
||||
|
|
@ -239,6 +241,10 @@ graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_ze
|
|||
Module ScriptModuleDeserializer::deserialize(
|
||||
c10::optional<at::Device> device,
|
||||
ExtraFilesMap& extra_files) {
|
||||
// we populate the upgraders map before any load starts
|
||||
#if ENABLE_UPGRADERS
|
||||
populate_upgraders_graph_map();
|
||||
#endif
|
||||
C10_LOG_API_USAGE_ONCE("torch.script.load");
|
||||
device_ = device;
|
||||
// Load extra files.
|
||||
|
|
|
|||
|
|
@ -577,7 +577,9 @@ void SourceImporterImpl::importClass(
|
|||
/*propResolvers=*/{},
|
||||
methods,
|
||||
method_resolvers,
|
||||
&self);
|
||||
&self,
|
||||
/*shouldMangle=*/false,
|
||||
/*operator_set_version=*/version_);
|
||||
cu_->define_hooks(
|
||||
qualified_classname,
|
||||
hooks,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/frontend/error_report.h>
|
||||
|
|
@ -13,6 +14,7 @@
|
|||
#include <torch/csrc/jit/ir/attributes.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/ir/ir_views.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/version_map.h>
|
||||
#include <torch/csrc/jit/resource_guard.h>
|
||||
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
|
||||
|
||||
|
|
@ -741,9 +743,22 @@ struct PythonPrintImpl {
|
|||
}
|
||||
}
|
||||
|
||||
void checkVersion(const Node* const node) {
|
||||
void checkVersion(Node* node) {
|
||||
#if ENABLE_UPGRADERS
|
||||
if (auto schema = node->maybeSchema()) {
|
||||
auto schema_name = getFullSchemaName(*schema);
|
||||
auto version_entry = get_operator_version_map().find(schema_name);
|
||||
if (version_entry != get_operator_version_map().end()) {
|
||||
const auto& entry = version_entry->second;
|
||||
// TODO (tugsuu) move this calculation into a seperate step.
|
||||
min_version_ = std::max(
|
||||
min_version_, uint64_t(entry[entry.size() - 1].bumped_at_version));
|
||||
}
|
||||
}
|
||||
#else
|
||||
min_version_ =
|
||||
std::max(min_version_, get_min_version_for_kind(node->kind()));
|
||||
#endif
|
||||
}
|
||||
|
||||
void printNode(Node* node, bool print_const) {
|
||||
|
|
@ -1594,7 +1609,11 @@ struct PythonPrintImpl {
|
|||
bool enforce_importable_;
|
||||
|
||||
// The least version that supports all printed ops
|
||||
#if ENABLE_UPGRADERS
|
||||
uint64_t min_version_ = caffe2::serialize::kMinSupportedFileFormatVersion;
|
||||
#else
|
||||
uint64_t min_version_ = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
PythonPrint::PythonPrint(
|
||||
|
|
|
|||
|
|
@ -147,11 +147,6 @@ def load(f, map_location=None, _extra_files=None):
|
|||
os.remove("scriptmodule.pt")
|
||||
"""
|
||||
|
||||
# TODO (tugsuu) Putting this on top level creates
|
||||
# circular import issue. We need to fix this later
|
||||
from torch.jit.operator_upgraders import populate_upgraders_map
|
||||
populate_upgraders_map()
|
||||
|
||||
if isinstance(f, string_classes):
|
||||
if not os.path.exists(f): # type: ignore[type-var]
|
||||
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]
|
||||
|
|
|
|||
|
|
@ -116,12 +116,5 @@ def generate_bytecode() -> List:
|
|||
yaml_content.append(entry)
|
||||
return yaml_content
|
||||
|
||||
def populate_upgraders_map():
|
||||
upgrader_set = collect_available_upgraders()
|
||||
content = {}
|
||||
for upgrader_name in upgrader_set:
|
||||
content[upgrader_name] = str(globals()[upgrader_name].graph)
|
||||
torch._C._populate_upgraders_map(content)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError("This file is not meant to be run directly")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user