Bump version number to 7 and compile old operators with old schema (#68358)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68358

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33433730

Pulled By: tugsbayasgalan

fbshipit-source-id: 202c58365bae13195d3545cefcb0da9162b02151
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar 2022-01-05 23:55:49 -08:00 committed by Facebook GitHub Bot
parent 8bdbe94344
commit b0fdca8855
32 changed files with 993 additions and 281 deletions

View File

@ -164,7 +164,13 @@ class TORCH_API PyTorchStreamWriter final {
std::string padding_;
std::ofstream file_stream_;
std::function<size_t(const void*, size_t)> writer_func_;
// This number will be updated when the model has operators
// that have valid upgraders.
#if ENABLE_UPGRADERS
uint64_t version_ = kMinProducedFileFormatVersion;
#else
uint64_t version_ = kProducedFileFormatVersion;
#endif
bool finalized_ = false;
bool err_seen_ = false;
friend size_t ostream_write_func(

View File

@ -1,12 +1,21 @@
#pragma once
#include <cstdint>
namespace caffe2 {
namespace serialize {
// Flag that controls if we want to enable upgraders
// in the server side. When this flag is set to False,
// it will switch to old dynamic versioning approach
#define ENABLE_UPGRADERS false
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
#if ENABLE_UPGRADERS
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x7L;
#else
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
#endif
// Versions (i.e. why was the version number bumped?)
@ -47,7 +56,23 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
// 5. (Dynamic) Stops torch.full inferring a floating point dtype
// when given bool or integer fill values.
// 6. Write version string to `./data/version` instead of `version`.
#if ENABLE_UPGRADERS
// This is set to 7 from 3 due to a different interpretation of what
// file format version is. Whenever there is new upgrader introduced,
// this number should be bumped.
// 1. aten::div is changed at version 4
// 2. aten::full is changed at version 5
// 3. torch.package uses version 6
constexpr uint64_t kProducedFileFormatVersion = 0x7L;
#else
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
#endif
// Absolute minimum version we will write packages. This
// means that every package from now on will always be
// greater than this number.
constexpr uint64_t kMinProducedFileFormatVersion = 0x3L;
// The version we write when the archive contains bytecode.
// It must be higher or eq to kProducedFileFormatVersion.

View File

@ -42,6 +42,7 @@ set(JIT_TEST_SRCS
${JIT_TEST_ROOT}/test_alias_analysis.cpp
${JIT_TEST_ROOT}/test_argument_spec.cpp
${JIT_TEST_ROOT}/test_autodiff.cpp
${JIT_TEST_ROOT}/test_load_upgraders.cpp
${JIT_TEST_ROOT}/test_op_replacement.cpp
${JIT_TEST_ROOT}/test_upgrader_utils.cpp
${JIT_TEST_ROOT}/test_backend.cpp

View File

@ -0,0 +1,45 @@
#include <caffe2/serialize/versions.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/serialization/import.h>
#include <test/cpp/jit/test_utils.h>
namespace torch {
namespace jit {
#if ENABLE_UPGRADERS
// Basic tests to check if C++ torch::jit::load
// can load the upgraders fine
// TODO (tugsuu) add more tests
TEST(UpgraderLoad, CanPopulateUpgradersGraph) {
Module m("m");
m.define(R"(
def forward(self, x: Tensor):
b = 5
return torch.div(x, b)
)");
std::stringstream ms;
m.save(ms);
auto loaded_m = torch::jit::load(ms);
auto version_map = get_operator_version_map();
auto upgraders = dump_upgraders_map();
for (const auto& entry : version_map) {
auto list_of_upgraders_for_op = entry.second;
for (const auto& upgrader_entry : list_of_upgraders_for_op) {
EXPECT_TRUE(
upgraders.find(upgrader_entry.upgrader_name) != upgraders.end());
}
}
auto test_graph = loaded_m.get_method("forward").graph();
// should have saved with version 4, so it is still up to date
testing::FileCheck().check_count("aten::div", 1, true)->run(*test_graph);
}
#endif
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,553 @@
# Owner(s): ["oncall: jit"]
from itertools import product as product
import io
import os
import random
import sys
import unittest
import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Legacy test cases for dynamic operator versioning
class TestLegacyUpgraders(JitTestCase):
reason: str = "Skipped due to new global operator versioning for operators"
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_symbols(self):
"""
Tests Torchscript symbol versioning. See note [Versioned Symbols].
This test uses an undocumented, test-only function
torch._test_serialization_subcmul.
This function is implemented as (a - alpha * b) with a default value
of 1 for alpha. In file format version 2, however, it was implemented
as (b - alpha * a) with a default value of 2 for alpha.
This test verifies a module seralized with file format version 2
exhibits the old behavior, and that the same module newly serialized
exhibits the current behavior.
"""
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a, b, alpha: float):
no_alpha = torch._test_serialization_subcmul(a, b)
with_alpha = torch._test_serialization_subcmul(a, b, alpha)
return no_alpha, with_alpha
def historic_subcmul(a, b, alpha=2):
return b - alpha * a
def current_subcmul(a, b, alpha=1):
return a - alpha * b
# Loads and verifies the historic behavior of the module
# that was serialized with version 2
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
a = torch.randn((5,))
b = torch.randn((5,))
alpha = random.random()
args = (a, b, alpha)
no_alpha_v2, with_alpha_v2 = module_v2(*args)
self.assertEqual(no_alpha_v2, historic_subcmul(a, b))
self.assertEqual(with_alpha_v2, historic_subcmul(*args))
# Scripts, saves, loads and verifies the current behavior of the module
scripted_module = torch.jit.script(MyModule())
buffer = io.BytesIO()
torch.jit.save(scripted_module, buffer)
buffer.seek(0)
module_current = torch.jit.load(buffer)
no_alpha_current, with_alpha_current = module_current(*args)
self.assertEqual(no_alpha_current, current_subcmul(a, b))
self.assertEqual(with_alpha_current, current_subcmul(*args))
# Helper that returns the module after saving and loading
def _save_load_module(self, m):
scripted_module = torch.jit.script(m())
buffer = io.BytesIO()
torch.jit.save(scripted_module, buffer)
buffer.seek(0)
return torch.jit.load(buffer)
# Helper which returns the result of a function or the exception the
# function threw.
def _try_fn(self, fn, *args, **kwargs):
try:
return fn(*args, **kwargs)
except Exception as e:
return e
def _verify_no(self, kind, m):
self._verify_count(kind, m, 0)
def _verify_count(self, kind, m, count):
node_count = sum(str(n).count(kind) for n in m.graph.nodes())
self.assertEqual(node_count, count)
"""
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
to call either aten::true_divide(_), if an input is a float type,
or truncated aten::divide(_) otherwise.
NOTE: currently compares against current div behavior, too, since
div behavior has not yet been updated.
"""
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_tensor(self):
def historic_div(self, other):
if self.is_floating_point() or other.is_floating_point():
return self.true_divide(other)
return self.divide(other, rounding_mode='trunc')
# Tensor x Tensor
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a, b):
result_0 = a / b
result_1 = torch.div(a, b)
result_2 = a.div(b)
return result_0, result_1, result_2
# Loads historic module
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 6) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
self._verify_count("aten::div", current_module, 3)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = torch.tensor((val_b,))
def _helper(m, fn):
m_results = self._try_fn(m, a, b)
fn_result = self._try_fn(fn, a, b)
if isinstance(m_results, Exception):
self.assertTrue(isinstance(fn_result, Exception))
else:
for result in m_results:
self.assertEqual(result, fn_result)
_helper(v3_module, historic_div)
_helper(current_module, torch.div)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_tensor_inplace(self):
def historic_div_(self, other):
if self.is_floating_point() or other.is_floating_point():
return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc')
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a, b):
a /= b
return a
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 2) # true_divide and divide both alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
self._verify_count("aten::div", current_module, 1)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = torch.tensor((val_b,))
def _helper(m, fn):
fn_result = self._try_fn(fn, a.clone(), b)
m_result = self._try_fn(m, a, b)
if isinstance(m_result, Exception):
self.assertTrue(fn_result, Exception)
else:
self.assertEqual(m_result, fn_result)
self.assertEqual(m_result, a)
_helper(v3_module, historic_div_)
# Recreates a since it was modified in place
a = torch.tensor((val_a,))
_helper(current_module, torch.Tensor.div_)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_tensor_out(self):
def historic_div_out(self, other, out):
if self.is_floating_point() or other.is_floating_point() or out.is_floating_point():
return torch.true_divide(self, other, out=out)
return torch.divide(self, other, out=out, rounding_mode='trunc')
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a, b, out):
return a.div(b, out=out)
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
self._verify_count("aten::div", current_module, 1)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = torch.tensor((val_b,))
for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)):
def _helper(m, fn):
fn_result = None
if fn is torch.div:
fn_result = self._try_fn(fn, a, b, out=out.clone())
else:
fn_result = self._try_fn(fn, a, b, out.clone())
m_result = self._try_fn(m, a, b, out)
if isinstance(m_result, Exception):
self.assertTrue(fn_result, Exception)
else:
self.assertEqual(m_result, fn_result)
self.assertEqual(m_result, out)
_helper(v3_module, historic_div_out)
_helper(current_module, torch.div)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_scalar(self):
def historic_div_scalar_float(self, other: float):
return torch.true_divide(self, other)
def historic_div_scalar_int(self, other: int):
if self.is_floating_point():
return torch.true_divide(self, other)
return torch.divide(self, other, rounding_mode='trunc')
class MyModuleFloat(torch.nn.Module):
def __init__(self):
super(MyModuleFloat, self).__init__()
def forward(self, a, b: float):
return a / b
class MyModuleInt(torch.nn.Module):
def __init__(self):
super(MyModuleInt, self).__init__()
def forward(self, a, b: int):
return a / b
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_int_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
for m in (v3_module_float, v3_module_int):
self._verify_count("aten::div", m, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
for m in (current_module_float, current_module_int):
self._verify_count("aten::div", m, 1)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = val_b
def _helper(m, fn):
m_result = self._try_fn(m, a, b)
fn_result = self._try_fn(fn, a, b)
if isinstance(m_result, Exception):
self.assertTrue(fn_result, Exception)
else:
self.assertEqual(m_result, fn_result)
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float)
_helper(current_module_float, torch.div)
else:
_helper(v3_module_int, historic_div_scalar_int)
_helper(current_module_int, torch.div)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_scalar_reciprocal(self):
def historic_div_scalar_float_reciprocal(self, other: float):
return other / self
def historic_div_scalar_int_reciprocal(self, other: int):
if self.is_floating_point():
return other / self
return torch.divide(other, self, rounding_mode='trunc')
class MyModuleFloat(torch.nn.Module):
def __init__(self):
super(MyModuleFloat, self).__init__()
def forward(self, a, b: float):
return b / a
class MyModuleInt(torch.nn.Module):
def __init__(self):
super(MyModuleInt, self).__init__()
def forward(self, a, b: int):
return b / a
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
# NOTE: number / tensor is rewritten to torch.reciprocal(a) * b
# so true_divide and floor_divide do not appear in their graphs
for m in (v3_module_float, v3_module_int):
self._verify_no("aten::div", m)
self._verify_no("aten::true_divide", m)
self._verify_no("aten::floor_divide", m)
self._verify_count("aten::reciprocal", m, 1)
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = val_b
def _helper(m, fn):
m_result = self._try_fn(m, a, b)
fn_result = None
# Reverses argument order for torch.div
if fn is torch.div:
fn_result = self._try_fn(torch.div, b, a)
else:
fn_result = self._try_fn(fn, a, b)
if isinstance(m_result, Exception):
self.assertTrue(isinstance(fn_result, Exception))
elif fn is torch.div or a.is_floating_point():
self.assertEqual(m_result, fn_result)
else:
# Skip when fn is not torch.div and a is integral because
# historic_div_scalar_int performs floored division
pass
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float_reciprocal)
_helper(current_module_float, torch.div)
else:
_helper(v3_module_int, historic_div_scalar_int_reciprocal)
_helper(current_module_int, torch.div)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_scalar_inplace(self):
def historic_div_scalar_float_inplace(self, other: float):
return self.true_divide_(other)
def historic_div_scalar_int_inplace(self, other: int):
if self.is_floating_point():
return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc')
class MyModuleFloat(torch.nn.Module):
def __init__(self):
super(MyModuleFloat, self).__init__()
def forward(self, a, b: float):
a /= b
return a
class MyModuleInt(torch.nn.Module):
def __init__(self):
super(MyModuleInt, self).__init__()
def forward(self, a, b: int):
a /= b
return a
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
for m in (v3_module_float, v3_module_int):
self._verify_count("aten::div_", m, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
for m in (current_module_float, current_module_int):
self._verify_count("aten::div", m, 1)
for m in (current_module_float, current_module_int):
self._verify_count("aten::div", m, 1)
vals = (2., 3., 2, 3)
for val_a, val_b in product(vals, vals):
a = torch.tensor((val_a,))
b = val_b
def _helper(m, fn):
m_result = self._try_fn(m, a, b)
fn_result = self._try_fn(fn, a, b)
if isinstance(m_result, Exception):
self.assertTrue(fn_result, Exception)
else:
self.assertEqual(m_result, fn_result)
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float_inplace)
_helper(current_module_float, torch.Tensor.div_)
else:
_helper(v3_module_int, historic_div_scalar_int_inplace)
_helper(current_module_int, torch.Tensor.div_)
# NOTE: Scalar division was already true division in op version 3,
# so this test verifies the behavior is unchanged.
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_div_scalar_scalar(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a: float, b: int, c: float, d: int):
result_0 = a / b
result_1 = a / c
result_2 = b / c
result_3 = b / d
return (result_0, result_1, result_2, result_3)
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 4)
current_module = self._save_load_module(MyModule)
self._verify_count("aten::div", current_module, 4)
def _helper(m, fn):
vals = (5., 3, 2., 7)
m_result = m(*vals)
fn_result = fn(*vals)
for mr, hr in zip(m_result, fn_result):
self.assertEqual(mr, hr)
_helper(v3_module, current_module)
# NOTE: the JIT was incapable of handling boolean fill values when
# PyTorch produced file format versions 0-4
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_full_integer_value(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, int_fill: int):
size = torch.Size(2, 2)
a = torch.full(size, int_fill)
b = torch.full(size, 1)
return (a, b)
try:
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::full", v4_module, 2)
current_module = self._save_load_module(MyModule)
self._verify_count("aten::full", current_module, 2)
# Verifies historic integer type inference is float
# NOTE: only verifies floating point, not exact dtype, due to
# https://github.com/pytorch/pytorch/issues/40470
results = v4_module(2)
for result in results:
self.assertTrue(result.is_floating_point())
# Verifies values are correct
a, b = results
self.assertTrue((a == 2.).all())
self.assertTrue((b == 1.).all())
# Tests that torch.full behavior which is the same from prior versions
# to version 5 is preserved.
# NOTE: while torch.full in eager PyTorch accepts a requires_grad argument,
# it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363)
@unittest.skipIf(torch._C._is_upgraders_enabled(), reason)
def test_versioned_full_preserved(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, float_fill: float):
size = (2, 2)
a = torch.full(size, 1.)
b = torch.full(size, float_fill)
c = torch.full(size, float_fill, dtype=torch.long)
out = torch.empty(size, dtype=torch.long)
d = torch.full(size, float_fill, out=out)
e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None,
layout=torch.strided, device='cpu')
return (a, b, c, d, e)
try:
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::full", v4_module, 5)
current_module = self._save_load_module(MyModule)
self._verify_count("aten::full", current_module, 5)
self.assertEqual(v4_module(2.), current_module(2.))

View File

@ -24,23 +24,6 @@ if __name__ == "__main__":
)
class TestSaveLoad(JitTestCase):
def test_versioned_symbols_reserialization(self):
"""
Tests that loading and saving serialized Torchscript with a versioned
symbol won't persist the original function and will inline the
versioned builtin.
"""
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
buffer = io.BytesIO()
torch.jit.save(module_v2, buffer)
buffer.seek(0)
module_reserialized = torch.jit.load(buffer)
subcmul_nodes = sum("subcmul" in n.kind() for
n in module_reserialized.graph.nodes())
self.assertEqual(subcmul_nodes, 0)
def test_different_modules(self):
"""
Exercise the situation where we have the same qualified name

View File

@ -3,7 +3,6 @@
from itertools import product as product
import io
import os
import random
import sys
import hypothesis.strategies as st
from hypothesis import example, settings, given
@ -24,55 +23,6 @@ if __name__ == "__main__":
)
class TestSaveLoadForOpVersion(JitTestCase):
def test_versioned_symbols(self):
"""
Tests Torchscript symbol versioning. See note [Versioned Symbols].
This test uses an undocumented, test-only function
torch._test_serialization_subcmul.
This function is implemented as (a - alpha * b) with a default value
of 1 for alpha. In file format version 2, however, it was implemented
as (b - alpha * a) with a default value of 2 for alpha.
This test verifies a module seralized with file format version 2
exhibits the old behavior, and that the same module newly serialized
exhibits the current behavior.
"""
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, a, b, alpha: float):
no_alpha = torch._test_serialization_subcmul(a, b)
with_alpha = torch._test_serialization_subcmul(a, b, alpha)
return no_alpha, with_alpha
def historic_subcmul(a, b, alpha=2):
return b - alpha * a
def current_subcmul(a, b, alpha=1):
return a - alpha * b
# Loads and verifies the historic behavior of the module
# that was serialized with version 2
module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
a = torch.randn((5,))
b = torch.randn((5,))
alpha = random.random()
args = (a, b, alpha)
no_alpha_v2, with_alpha_v2 = module_v2(*args)
self.assertEqual(no_alpha_v2, historic_subcmul(a, b))
self.assertEqual(with_alpha_v2, historic_subcmul(*args))
# Scripts, saves, loads and verifies the current behavior of the module
scripted_module = torch.jit.script(MyModule())
buffer = io.BytesIO()
torch.jit.save(scripted_module, buffer)
buffer.seek(0)
module_current = torch.jit.load(buffer)
no_alpha_current, with_alpha_current = module_current(*args)
self.assertEqual(no_alpha_current, current_subcmul(a, b))
self.assertEqual(with_alpha_current, current_subcmul(*args))
# Helper that returns the module after saving and loading
def _save_load_module(self, m):
scripted_module = torch.jit.script(m())
@ -107,7 +57,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
to call either aten::true_divide(_), if an input is a float type,
or truncated aten::divide(_) otherwise.
NOTE: currently compares against current div behavior, too, since
div behavior has not yet been updated.
"""
@ -137,18 +86,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
# Loads historic module
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt")
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 6) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
current_mobile_module = self._save_load_mobile_module(MyModule)
self._verify_count("aten::div", current_module, 3)
for val_a, val_b in product(sample_input, sample_input):
a = torch.tensor((val_a,))
@ -164,9 +107,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
for result in m_results:
self.assertEqual(result, fn_result)
_helper(v3_module, historic_div)
_helper(v3_mobile_module, historic_div)
_helper(current_module, torch.div)
_helper(current_mobile_module, torch.div)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
@ -189,18 +130,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
return a
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt")
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 2) # true_divide and divide both alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
current_mobile_module = self._save_load_mobile_module(MyModule)
self._verify_count("aten::div", current_module, 1)
for val_a, val_b in product(sample_input, sample_input):
a = torch.tensor((val_a,))
@ -215,12 +150,10 @@ class TestSaveLoadForOpVersion(JitTestCase):
self.assertEqual(m_result, fn_result)
self.assertEqual(m_result, a)
_helper(v3_module, historic_div_)
_helper(v3_mobile_module, historic_div_)
# Recreates a since it was modified in place
a = torch.tensor((val_a,))
_helper(current_module, torch.Tensor.div_)
_helper(current_mobile_module, torch.Tensor.div_)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
@ -242,18 +175,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
return a.div(b, out=out)
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt")
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument
current_module = self._save_load_module(MyModule)
current_mobile_module = self._save_load_mobile_module(MyModule)
self._verify_count("aten::div", current_module, 1)
for val_a, val_b in product(sample_input, sample_input):
a = torch.tensor((val_a,))
@ -274,8 +201,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
self.assertEqual(m_result, fn_result)
self.assertEqual(m_result, out)
_helper(v3_module, historic_div_out)
_helper(current_module, torch.div)
_helper(v3_mobile_module, historic_div_out)
_helper(current_mobile_module, torch.div)
@ -308,8 +233,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
return a / b
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v3.pt")
v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl")
v3_mobile_module_int = _load_for_lite_interpreter(
@ -321,14 +244,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
self._verify_count("aten::div", m, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)
for m in (current_module_float, current_module_int):
self._verify_count("aten::div", m, 1)
for val_a, val_b in product(sample_input, sample_input):
a = torch.tensor((val_a,))
b = val_b
@ -343,13 +261,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
self.assertEqual(m_result, fn_result)
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float)
_helper(current_module_float, torch.div)
_helper(v3_mobile_module_float, current_mobile_module_float)
_helper(current_mobile_module_float, torch.div)
else:
_helper(v3_module_int, historic_div_scalar_int)
_helper(current_module_int, torch.div)
_helper(v3_mobile_module_int, historic_div_scalar_int)
_helper(current_mobile_module_int, torch.div)
@ -382,8 +296,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
return b / a
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt")
v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl")
v3_mobile_module_int = _load_for_lite_interpreter(
@ -391,17 +303,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
except Exception as e:
self.skipTest("Failed to load fixture!")
# NOTE: number / tensor is rewritten to torch.reciprocal(a) * b
# so true_divide and floor_divide do not appear in their graphs
for m in (v3_module_float, v3_module_int):
self._verify_no("aten::div", m)
self._verify_no("aten::true_divide", m)
self._verify_no("aten::floor_divide", m)
self._verify_count("aten::reciprocal", m, 1)
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)
@ -428,13 +329,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
pass
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float_reciprocal)
_helper(current_module_float, torch.div)
_helper(v3_mobile_module_float, current_mobile_module_float)
_helper(current_mobile_module_float, torch.div)
else:
_helper(v3_module_int, historic_div_scalar_int_reciprocal)
_helper(current_module_int, torch.div)
_helper(v3_mobile_module_int, current_mobile_module_int)
_helper(current_mobile_module_int, torch.div)
@ -470,9 +367,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
return a
try:
v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt")
v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt")
v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl")
v3_mobile_module_int = _load_for_lite_interpreter(
@ -480,22 +374,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
except Exception as e:
self.skipTest("Failed to load fixture!")
for m in (v3_module_float, v3_module_int):
self._verify_count("aten::div_", m, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
current_mobile_module_float = self._save_load_module(MyModuleFloat)
current_mobile_module_int = self._save_load_module(MyModuleInt)
for m in (current_module_float, current_module_int):
self._verify_count("aten::div", m, 1)
for m in (current_module_float, current_module_int):
self._verify_count("aten::div", m, 1)
for val_a, val_b in product(sample_input, sample_input):
a = torch.tensor((val_a,))
b = val_b
@ -510,12 +391,8 @@ class TestSaveLoadForOpVersion(JitTestCase):
self.assertEqual(m_result, fn_result)
if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float_inplace)
_helper(current_module_float, torch.Tensor.div_)
_helper(current_mobile_module_float, torch.Tensor.div_)
else:
_helper(v3_module_int, historic_div_scalar_int_inplace)
_helper(current_module_int, torch.Tensor.div_)
_helper(current_mobile_module_int, torch.Tensor.div_)
# NOTE: Scalar division was already true division in op version 3,
@ -533,17 +410,12 @@ class TestSaveLoadForOpVersion(JitTestCase):
return (result_0, result_1, result_2, result_3)
try:
v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt")
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::div", v3_module, 4)
current_module = self._save_load_module(MyModule)
current_mobile_module = self._save_load_mobile_module(MyModule)
self._verify_count("aten::div", current_module, 4)
def _helper(m, fn):
vals = (5., 3, 2., 7)
@ -552,74 +424,4 @@ class TestSaveLoadForOpVersion(JitTestCase):
for mr, hr in zip(m_result, fn_result):
self.assertEqual(mr, hr)
_helper(v3_module, current_module)
_helper(v3_mobile_module, current_mobile_module)
# NOTE: the JIT was incapable of handling boolean fill values when
# PyTorch produced file format versions 0-4
def test_versioned_full_integer_value(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, int_fill: int):
size = torch.Size(2, 2)
a = torch.full(size, int_fill)
b = torch.full(size, 1)
return (a, b)
try:
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::full", v4_module, 2)
current_module = self._save_load_module(MyModule)
self._verify_count("aten::full", current_module, 2)
# Verifies historic integer type inference is float
# NOTE: only verifies floating point, not exact dtype, due to
# https://github.com/pytorch/pytorch/issues/40470
results = v4_module(2)
for result in results:
self.assertTrue(result.is_floating_point())
# Verifies values are correct
a, b = results
self.assertTrue((a == 2.).all())
self.assertTrue((b == 1.).all())
# Tests that torch.full behavior which is the same from prior versions
# to version 5 is preserved.
# NOTE: while torch.full in eager PyTorch accepts a requires_grad argument,
# it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363)
def test_versioned_full_preserved(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, float_fill: float):
size = (2, 2)
a = torch.full(size, 1.)
b = torch.full(size, float_fill)
c = torch.full(size, float_fill, dtype=torch.long)
out = torch.empty(size, dtype=torch.long)
d = torch.full(size, float_fill, out=out)
e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None,
layout=torch.strided, device='cpu')
return (a, b, c, d, e)
try:
v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt")
except Exception as e:
self.skipTest("Failed to load fixture!")
self._verify_count("aten::full", v4_module, 5)
current_module = self._save_load_module(MyModule)
self._verify_count("aten::full", current_module, 5)
self.assertEqual(v4_module(2.), current_module(2.))

View File

@ -4,6 +4,11 @@ import io
import os
import sys
import torch
import unittest
import zipfile
from torch.testing import FileCheck
from torch._C import _is_upgraders_enabled
from typing import Union
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -16,6 +21,15 @@ if __name__ == '__main__':
"instead.")
class TestUpgraders(JitTestCase):
def _load_model_version(self, loaded_model):
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
zipped_model = zipfile.ZipFile(buffer)
version = int(zipped_model.read('archive/version').decode("utf-8"))
return version
# TODO (tugsuu) We should ideally be generating this test cases.
def test_populated_upgrader_graph(self):
@torch.jit.script
def f():
@ -67,7 +81,7 @@ class TestUpgraders(JitTestCase):
# upgrader map should have populated now
upgraders_size = torch._C._get_upgraders_map_size()
test_map = {"a": "b", "c": "d"}
test_map = {"a": str(torch._C.Graph()), "c": str(torch._C.Graph())}
torch._C._test_only_populate_upgraders(test_map)
upgraders_size_after_test = torch._C._get_upgraders_map_size()
self.assertEqual(upgraders_size_after_test - upgraders_size, 2)
@ -81,3 +95,117 @@ class TestUpgraders(JitTestCase):
upgraders_dump_after_remove_test = torch._C._dump_upgraders_map()
self.assertTrue("a" not in upgraders_dump_after_remove_test)
self.assertTrue("c" not in upgraders_dump_after_remove_test)
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_div_tensor_at_3(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt"
loaded_model = torch.jit.load(model_path)
# there are 3 aten::div in this model
# And the upgrader for aten::div uses two
# div's because of if/else branch
FileCheck().check("prim::If").run(loaded_model.graph)
FileCheck().check_count("aten::div", 6).run(loaded_model.graph)
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 4)
loaded_model_twice = torch.jit.load(buffer)
# we check by its code because graph variable names
# can be different every time
self.assertEqual(loaded_model.code, loaded_model_twice.code)
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_test_serialization(self):
model_path = pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
# add test version entry to the version map
upgrader_bumped_version = 3
upgrader_name = "_test_serialization_subcmul_0_2"
upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema)
torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry)
# add test upgrader in the upgraders map
@torch.jit.script
def _test_serialization_subcmul_0_2(self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2) -> torch.Tensor:
return other - (self * alpha)
torch._C._test_only_populate_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
# test if the server is able to find the test upgraders and apply to IR
loaded_model = torch.jit.load(model_path)
FileCheck().check_count("aten::mul", 2).run(loaded_model.graph)
FileCheck().check_count("aten::sub", 2).run(loaded_model.graph)
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 3)
loaded_model_twice = torch.jit.load(buffer)
# we check by its' code because graph variable names
# can be different every time
self.assertEqual(loaded_model.code, loaded_model_twice.code)
torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul")
torch._C._test_only_remove_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_div_scalar_at_3(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt"
loaded_model = torch.jit.load(model_path)
FileCheck().check("prim::If").run(loaded_model.graph)
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 4)
loaded_model_twice = torch.jit.load(buffer)
self.assertEqual(loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0))
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_div_tensor_out_at_3(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt"
loaded_model = torch.jit.load(model_path)
FileCheck().check("prim::If").run(loaded_model.graph)
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 4)
loaded_model_twice = torch.jit.load(buffer)
# we check by its' code because graph variable names
# can be different every time
self.assertEqual(loaded_model.code, loaded_model_twice.code)
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_full_at_4(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt"
loaded_model = torch.jit.load(model_path)
FileCheck().check_count("aten::Float", 1).run(loaded_model.graph)
FileCheck().check_count("aten::full", 2).run(loaded_model.graph)
buffer = io.BytesIO()
torch.jit.save(loaded_model, buffer)
buffer.seek(0)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 5)
loaded_model_twice = torch.jit.load(buffer)
# we check by its' code because graph variable names
# can be different every time
self.assertEqual(loaded_model.code, loaded_model_twice.code)
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_full_out_at_4(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt"
loaded_model = torch.jit.load(model_path)
FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 5)

View File

@ -68,6 +68,7 @@ from jit.test_attr import TestGetDefaultAttr # noqa: F401
from jit.test_aten_pow import TestAtenPow # noqa: F401
from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401
from jit.test_union import TestUnion # noqa: F401
from jit.test_legacy_upgraders import TestLegacyUpgraders # noqa: F401
from jit.test_models import MnistNet
from jit.test_batch_mm import TestBatchMM # noqa: F401
from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401

View File

@ -112,6 +112,7 @@ core_sources_common = [
"torch/csrc/jit/mobile/type_parser.cpp",
"torch/csrc/jit/mobile/runtime_compatibility.cpp",
"torch/csrc/jit/operator_upgraders/version_map.cpp",
"torch/csrc/jit/operator_upgraders/upgraders_guard.cpp",
"torch/csrc/jit/runtime/instruction.cpp",
"torch/csrc/jit/runtime/jit_exception.cpp",
"torch/csrc/jit/runtime/operator.cpp",
@ -121,7 +122,6 @@ core_sources_common = [
"torch/csrc/jit/runtime/vararg_functions.cpp",
"torch/csrc/jit/mobile/promoted_prim_ops.cpp",
"torch/csrc/jit/mobile/prim_ops_registery.cpp",
"torch/csrc/jit/operator_upgraders/upgraders.cpp",
"torch/csrc/profiler/util.cpp",
]
@ -209,6 +209,8 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/mobile/nnc/context.cpp",
"torch/csrc/jit/mobile/nnc/registry.cpp",
"torch/csrc/jit/operator_upgraders/utils.cpp",
"torch/csrc/jit/operator_upgraders/upgraders_entry.cpp",
"torch/csrc/jit/operator_upgraders/upgraders.cpp",
"torch/csrc/jit/passes/annotate_warns.cpp",
"torch/csrc/jit/passes/bailout_graph.cpp",
"torch/csrc/jit/passes/batch_mm.cpp",

View File

@ -293,7 +293,7 @@ def _create_function_from_trace(
def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
def parse_type_comment(comment: str) -> Decl: ...
def _populate_upgraders_map(content: Dict[str, str]) -> None: ...
def _is_upgraders_enabled() -> _bool: ...
def _get_upgraders_map_size() -> _int: ...
def _dump_upgraders_map() -> Dict[str, str]: ...
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...

View File

@ -1,5 +1,6 @@
#include <torch/csrc/jit/frontend/builtin_functions.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/frontend/code_template.h>
#include <torch/csrc/jit/frontend/resolver.h>
@ -85,6 +86,7 @@ def __contains__(self: str, key: str):
return self.find(key, 0, len(self)) != -1
)SCRIPT";
#if !ENABLE_UPGRADERS
// Implementations of historic symbol behaviors are defined here
// See note [Versioned Symbols]
@ -158,7 +160,6 @@ def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
pin_memory:Optional[bool]=None) -> Tensor:
if dtype is None:
fill_value = float(fill_value)
return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
)SCRIPT";
@ -168,6 +169,7 @@ auto full_out = R"SCRIPT(
def full_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
return torch.full(size, fill_value, out=out)
)SCRIPT";
#endif
struct BuiltinFunctionRegistry {
const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
@ -237,6 +239,7 @@ struct BuiltinFunctionRegistry {
loadSource(aten_ops, "aten");
loadSource(aten_ops_additional, "aten");
#if !ENABLE_UPGRADERS
// Loads functions implementing historic behavior, see note [Versioned
// Symbols]
// Note: these functions go into the "upgraders" namespace
@ -249,6 +252,7 @@ struct BuiltinFunctionRegistry {
loadSource(div__scalar, "upgraders");
loadSource(full, "upgraders");
loadSource(full_out, "upgraders");
#endif
// These are under `prim` instead of `aten` since they exist to bind certain
// tensor property getters to correpsonding methods

View File

@ -4,6 +4,7 @@
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
#include <torch/csrc/jit/frontend/convert_to_ssa.h>
@ -22,6 +23,7 @@
#include <torch/csrc/jit/passes/lift_closures.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <torch/csrc/jit/passes/replacement_of_old_operators.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
@ -656,6 +658,13 @@ struct to_ir {
}
method.setSchema(emitDef(def, self, graph->block()));
#if ENABLE_UPGRADERS
// At this point, we might have received a graph that is compiled with
// old operator schemas that might not exist in the system anymore.
// Therefore, we replace such ops with its' valid upgrader.
ReplaceOldOperatorsWithUpgraders(graph);
#endif
// NB ORDERING: SSA conversion has to occur before
// lifting of closures and forks, this way closures are converted
// to SSA while part of their original graph, and closures are ready to

View File

@ -4,6 +4,7 @@
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/frontend/builtin_functions.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>
@ -469,7 +470,7 @@ static c10::optional<MatchedSchema> tryMatchSchema(
}
// construct the full name of the schema for easier look up
auto schema_name = schema.operator_name().name + "." + schema.overload_name();
auto schema_name = getFullSchemaName(schema);
return MatchedSchema{
std::move(positional_inputs),
@ -619,6 +620,13 @@ static Value* emitBuiltinNode(
return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
}
std::string getFullSchemaName(const ::c10::FunctionSchema& schema) {
if (schema.overload_name() != "") {
return schema.operator_name().name + "." + schema.overload_name();
}
return schema.operator_name().name;
}
// Search for operators matching the provided symbol name and input types.
// If one is found, emit a node to the graph for that operator.
Value* emitBuiltinCall(
@ -631,8 +639,12 @@ Value* emitBuiltinCall(
const auto& variants = getAllOperatorsFor(name);
const auto& builtin_functions = getAllBuiltinFunctionsFor(name);
#if ENABLE_UPGRADERS
// first let's set the graph's version
auto graph_version = graph.get_op_version();
#else
c10::optional<size_t> graph_version = c10::nullopt;
#endif
std::stringstream failure_messages;
std::vector<const FunctionSchema*> schemas;
@ -643,9 +655,7 @@ Value* emitBuiltinCall(
schemas.reserve(variants.size());
for (const std::shared_ptr<Operator>& op : variants) {
bool found_upgrader = false;
auto op_overload_name = op->schema().overload_name();
auto op_name = op->schema().operator_name().name +
((op_overload_name != "") ? "." + op_overload_name : "");
auto op_name = getFullSchemaName(op->schema());
if (graph_version.has_value()) {
auto version_entry = get_operator_version_map().find(op_name);
if (version_entry != get_operator_version_map().end()) {

View File

@ -41,6 +41,8 @@ TORCH_API bool convertibleToList(
const TypePtr& type,
const TypePtr& list_type_);
TORCH_API std::string getFullSchemaName(const ::c10::FunctionSchema& schema);
TORCH_API Value* emitBuiltinCall(
const SourceRange& loc,
Graph& graph,

View File

@ -5,6 +5,7 @@
#include <utility>
#include <ATen/core/symbol.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
@ -319,12 +320,14 @@ struct TORCH_API BuiltinModule : public SugaredValue {
}
auto sym = Symbol::fromQualString(name + "::" + field);
#if !ENABLE_UPGRADERS
if (version.has_value()) {
// Possibly replaces symbol with another that implements its
// historic behavior.
// See note [Versioned Symbols]
sym = get_symbol_for_version(sym, *version);
}
#endif
return std::make_shared<BuiltinFunction>(sym, c10::nullopt);
}

View File

@ -1,12 +1,13 @@
#include <torch/csrc/jit/frontend/versioned_symbols.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/api/include/torch/jit.h>
#include <unordered_map>
namespace torch {
namespace jit {
#if !ENABLE_UPGRADERS
// Note [Versioned Symbols]
// When the schema or behavior of a symbol changes, serialized Torchscript
// programs using that symbol are likely to break. To prevent those breaks,
@ -105,6 +106,7 @@ uint64_t get_min_version_for_kind(const NodeKind& kind) {
return it->second;
}
#endif
} // namespace jit
} // namespace torch

View File

@ -1,5 +1,6 @@
#pragma once
#include <caffe2/serialize/versions.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/api/module.h>
@ -7,7 +8,7 @@
namespace torch {
namespace jit {
#if !ENABLE_UPGRADERS
// Maps the given symbol into an implementation of its behavior at the
// given version.
// See note [Versioned Symbols]
@ -17,6 +18,6 @@ get_symbol_for_version(const Symbol name, const uint64_t version);
// Maps the given kind to the minimum version that supports it.
// See note [Dynamic Versions and torch.jit.save vs. torch.save]
TORCH_API uint64_t get_min_version_for_kind(const NodeKind& kind);
#endif
} // namespace jit
} // namespace torch

View File

@ -1,4 +1,7 @@
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <mutex>
#include <string>
#include <unordered_map>
@ -9,12 +12,13 @@ namespace jit {
static UpgradersMap upgradersMap;
void UpgradersMap::set_content(
std::unordered_map<std::string, std::string>&& content) {
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
// make sure we populate the map only once
std::lock_guard<std::mutex> _(lock);
if (isPopulated) {
return;
}
content_ = std::move(content);
isPopulated = true;
}
@ -24,7 +28,12 @@ int UpgradersMap::count() {
return content_.size();
}
const std::unordered_map<std::string, std::string>& UpgradersMap::
bool UpgradersMap::is_populated() {
std::lock_guard<std::mutex> _(lock);
return isPopulated;
}
const std::unordered_map<std::string, std::shared_ptr<Graph>>& UpgradersMap::
get_content() {
std::lock_guard<std::mutex> _(lock);
return content_;
@ -34,7 +43,9 @@ void UpgradersMap::test_only_set_content(
const std::unordered_map<std::string, std::string>& content) {
std::lock_guard<std::mutex> _(lock);
for (const auto& entry : content) {
content_.insert(entry);
auto graph = std::make_shared<Graph>();
torch::jit::parseIR(entry.second, graph.get());
content_.insert(std::make_pair(entry.first, graph));
}
}
void UpgradersMap::test_only_remove_content(
@ -46,7 +57,7 @@ void UpgradersMap::test_only_remove_content(
}
void populate_upgraders_map(
std::unordered_map<std::string, std::string>&& content) {
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
upgradersMap.set_content(std::move(content));
}
@ -54,7 +65,12 @@ int get_upgraders_map_size() {
return upgradersMap.count();
}
const std::unordered_map<std::string, std::string>& dump_upgraders_map() {
bool is_upgraders_map_populated() {
return upgradersMap.is_populated();
}
const std::unordered_map<std::string, std::shared_ptr<Graph>>&
dump_upgraders_map() {
return upgradersMap.get_content();
}

View File

@ -1,5 +1,6 @@
#pragma once
#include <c10/macros/Export.h>
#include <torch/csrc/jit/ir/ir.h>
#include <iostream>
#include <mutex>
#include <string>
@ -10,9 +11,11 @@ namespace jit {
class UpgradersMap {
public:
void set_content(std::unordered_map<std::string, std::string>&& content);
void set_content(
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content);
int count();
const std::unordered_map<std::string, std::string>& get_content();
const std::unordered_map<std::string, std::shared_ptr<Graph>>& get_content();
bool is_populated();
// THESE METHODS ARE ONLY USED FOR TESTING PURPOSES
void test_only_set_content(
const std::unordered_map<std::string, std::string>& content);
@ -20,17 +23,19 @@ class UpgradersMap {
const std::unordered_map<std::string, std::string>& content);
private:
std::unordered_map<std::string, std::string> content_;
std::unordered_map<std::string, std::shared_ptr<Graph>> content_;
std::mutex lock;
bool isPopulated = false;
};
TORCH_API void populate_upgraders_map(
std::unordered_map<std::string, std::string>&& content);
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content);
TORCH_API int get_upgraders_map_size();
TORCH_API const std::unordered_map<std::string, std::string>&
TORCH_API bool is_upgraders_map_populated();
TORCH_API const std::unordered_map<std::string, std::shared_ptr<Graph>>&
dump_upgraders_map();
// THESE TWO METHODS BELOW ARE ONLY USED FOR TESTING

View File

@ -0,0 +1,78 @@
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
#include <ATen/core/stack.h>
#include <c10/macros/Export.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
#include <string>
#include <unordered_map>
namespace torch {
namespace jit {
static std::unordered_map<std::string, std::string> kUpgradersEntryMap(
{{"div_Tensor_0_3", R"SCRIPT(
def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
if (self.is_floating_point() or other.is_floating_point()):
return self.true_divide(other)
return self.divide(other, rounding_mode='trunc')
)SCRIPT"},
{"div_Scalar_0_3", R"SCRIPT(
def div_Scalar_0_3(self: Tensor, other: number) -> Tensor:
if (self.is_floating_point() or isinstance(other, float)):
return self.true_divide(other)
return self.divide(other, rounding_mode='trunc')
)SCRIPT"},
{"div_out_0_3", R"SCRIPT(
def div_out_0_3(self: Tensor, other: Tensor, *, out: Tensor) -> Tensor:
if (self.is_floating_point() or other.is_floating_point() or out.is_floating_point()):
return self.true_divide(other, out=out)
return self.divide(other, rounding_mode='trunc', out=out)
)SCRIPT"},
{"div__Tensor_0_3", R"SCRIPT(
def div__Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
if (self.is_floating_point() or other.is_floating_point()):
return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc')
)SCRIPT"},
{"div__Scalar_0_3", R"SCRIPT(
def div__Scalar_0_3(self: Tensor, other: number) -> Tensor:
if (self.is_floating_point() or isinstance(other, float)):
return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc')
)SCRIPT"},
{"full_0_4", R"SCRIPT(
def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
layout:Optional[int]=None, device:Optional[Device]=None,
pin_memory:Optional[bool]=None) -> Tensor:
if dtype is None:
fill_value = float(fill_value)
return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
)SCRIPT"},
{"full_out_0_4", R"SCRIPT(
def full_out_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
return torch.full(size, fill_value, out=out)
)SCRIPT"}});
using UpgraderMap = std::unordered_map<std::string, std::shared_ptr<Graph>>;
void populate_upgraders_graph_map() {
if (!is_upgraders_map_populated()) {
UpgraderMap populate_content;
for (const auto& entry : kUpgradersEntryMap) {
auto cu = std::make_shared<CompilationUnit>();
cu->define(c10::nullopt, entry.second, nativeResolver(), nullptr);
Function& jitFunc = cu->get_function(entry.first);
GraphFunction& graphFunction = toGraphFunction(jitFunc);
populate_content.insert(
std::make_pair(entry.first, graphFunction.graph()));
}
populate_upgraders_map(std::forward<UpgraderMap>(populate_content));
}
}
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,14 @@
#pragma once
#include <c10/macros/Export.h>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>
namespace torch {
namespace jit {
TORCH_API void populate_upgraders_graph_map();
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,12 @@
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/operator_upgraders/upgraders_guard.h>
namespace torch {
namespace jit {
bool is_upgraders_enabled() {
return ENABLE_UPGRADERS;
}
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,10 @@
#pragma once
#include <c10/macros/Export.h>
namespace torch {
namespace jit {
TORCH_API bool is_upgraders_enabled();
} // namespace jit
} // namespace torch

View File

@ -31,11 +31,11 @@ static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersi
{"aten::div_.Tensor",
{{4,
"div__Tensor_0_3",
"aten::div_.Tensor(Tensor(a!), Tensor other) -> Tensor(a!)"}}},
"aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"}}},
{"aten::div_.Scalar",
{{4,
"div__Scalar_0_3",
"aten::div_.Scalar(Tensor(a!), Tensor other) -> Tensor(a!)"}}},
"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"}}},
{"aten::full",
{{5,
"full_0_4",

View File

@ -15,29 +15,6 @@
namespace torch {
namespace jit {
static std::unordered_map<std::string, std::shared_ptr<Graph>> upgraderCache;
std::shared_ptr<Graph> getUpgraderGraph(const std::string& upgrader_name) {
auto it = upgraderCache.find(upgrader_name);
if (it != upgraderCache.end()) {
return it->second;
}
auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name);
TORCH_INTERNAL_ASSERT(
upgrader_graph_entry != dump_upgraders_map().end(),
"Corresponding upgrader graph for ",
upgrader_name,
" must exist.",
" This upgrader"
" might be deprecated.");
auto upgrader_graph = std::make_shared<Graph>();
parseIR(upgrader_graph_entry->second, upgrader_graph.get());
upgraderCache[upgrader_name] = upgrader_graph;
return upgrader_graph;
}
struct OldOpsReplacerWithUpgraders {
OldOpsReplacerWithUpgraders(std::shared_ptr<Graph> graph)
: graph_(std::move(graph)) {}
@ -52,9 +29,7 @@ struct OldOpsReplacerWithUpgraders {
Node* node = graph_it.next();
while (node) {
if (auto schema = node->maybeSchema()) {
auto schema_name = schema->name() +
(schema->overload_name() != "" ? "." + schema->overload_name()
: "");
auto schema_name = getFullSchemaName(*schema);
// this implies there was a version bump because of this operator
auto version_entry = get_operator_version_map().find(schema_name);
if (version_entry != get_operator_version_map().end()) {
@ -74,8 +49,16 @@ struct OldOpsReplacerWithUpgraders {
}
auto upgrader_entry_val = upgrader_entry.value();
auto upgrader_name = upgrader_entry_val.upgrader_name;
auto upgrader_graph = getUpgraderGraph(upgrader_name);
auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name);
TORCH_INTERNAL_ASSERT(
upgrader_graph_entry != dump_upgraders_map().end(),
"Corresponding upgrader graph for ",
upgrader_name,
" must exist.",
" This upgrader"
" might be deprecated.");
auto upgrader_graph = upgrader_graph_entry->second;
// inline the upgrader function body
WithInsertPoint guard(node);
auto new_outputs = insertGraph(

View File

@ -14,6 +14,7 @@
#include <torch/csrc/jit/mobile/model_compatibility.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
#include <torch/csrc/jit/operator_upgraders/upgraders_guard.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/python_ivalue.h>
@ -1731,7 +1732,8 @@ void initJitScriptBindings(PyObject* module) {
return Decl(p.parseTypeComment());
});
m.def("_populate_upgraders_map", &populate_upgraders_map);
m.def("_is_upgraders_enabled", &is_upgraders_enabled);
m.def("_get_upgraders_map_size", &get_upgraders_map_size);
m.def("_dump_upgraders_map", &dump_upgraders_map);

View File

@ -10,6 +10,7 @@
#endif
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <torch/csrc/jit/serialization/import_source.h>
@ -20,6 +21,7 @@
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/istream_adapter.h>
#include <caffe2/serialize/versions.h>
#include <ATen/ATen.h>
#include <fmt/format.h>
@ -239,6 +241,10 @@ graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_ze
Module ScriptModuleDeserializer::deserialize(
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
// we populate the upgraders map before any load starts
#if ENABLE_UPGRADERS
populate_upgraders_graph_map();
#endif
C10_LOG_API_USAGE_ONCE("torch.script.load");
device_ = device;
// Load extra files.

View File

@ -577,7 +577,9 @@ void SourceImporterImpl::importClass(
/*propResolvers=*/{},
methods,
method_resolvers,
&self);
&self,
/*shouldMangle=*/false,
/*operator_set_version=*/version_);
cu_->define_hooks(
qualified_classname,
hooks,

View File

@ -6,6 +6,7 @@
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/error_report.h>
@ -13,6 +14,7 @@
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
@ -741,9 +743,22 @@ struct PythonPrintImpl {
}
}
void checkVersion(const Node* const node) {
void checkVersion(Node* node) {
#if ENABLE_UPGRADERS
if (auto schema = node->maybeSchema()) {
auto schema_name = getFullSchemaName(*schema);
auto version_entry = get_operator_version_map().find(schema_name);
if (version_entry != get_operator_version_map().end()) {
const auto& entry = version_entry->second;
// TODO (tugsuu) move this calculation into a seperate step.
min_version_ = std::max(
min_version_, uint64_t(entry[entry.size() - 1].bumped_at_version));
}
}
#else
min_version_ =
std::max(min_version_, get_min_version_for_kind(node->kind()));
#endif
}
void printNode(Node* node, bool print_const) {
@ -1594,7 +1609,11 @@ struct PythonPrintImpl {
bool enforce_importable_;
// The least version that supports all printed ops
#if ENABLE_UPGRADERS
uint64_t min_version_ = caffe2::serialize::kMinSupportedFileFormatVersion;
#else
uint64_t min_version_ = 0;
#endif
};
PythonPrint::PythonPrint(

View File

@ -147,11 +147,6 @@ def load(f, map_location=None, _extra_files=None):
os.remove("scriptmodule.pt")
"""
# TODO (tugsuu) Putting this on top level creates
# circular import issue. We need to fix this later
from torch.jit.operator_upgraders import populate_upgraders_map
populate_upgraders_map()
if isinstance(f, string_classes):
if not os.path.exists(f): # type: ignore[type-var]
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]

View File

@ -116,12 +116,5 @@ def generate_bytecode() -> List:
yaml_content.append(entry)
return yaml_content
def populate_upgraders_map():
upgrader_set = collect_available_upgraders()
content = {}
for upgrader_name in upgrader_set:
content[upgrader_name] = str(globals()[upgrader_name].graph)
torch._C._populate_upgraders_map(content)
if __name__ == "__main__":
raise RuntimeError("This file is not meant to be run directly")