Summary: PyTorch users write programs and save them as serialized TorchScript. When this TorchScript is loaded it contains symbols like "aten::div" describing some of the program's behavior. If the behavior of these symbols has changed since the program was serialized, however, then the original program's semantics may not be preserved. For example, when we make aten::div always perform "true" division, like NumPy, Python 3, and JAX, then serialized TorchScript programs relying on aten::div performing floor division on integral inputs will break.

This PR demonstrates the "Versioned Symbol" pattern that lets symbols be remapped to TorchScript builtins that preserve their historic behavior. Using this pattern, after we update aten::div to always perform true division, we could remap it in older TorchScript programs to a builtin that implements its historic behavior. The pattern is described in the [Versioned Symbols] note in the code and is implemented like this:

- If BuiltinModule is given a version, then before it returns a symbol it queries to see if another symbol should be substituted for it.
- versioned_symbol.cpp has a map from each such symbol to the version range for which another symbol should be substituted for it.
- The substitutions are implemented as builtin functions.

An example using the new, test-only _test_serialization_subcmul function is implemented to test this behavior. A test in jit/test_save_load.py follows the pattern described in the [Versioned Symbols] note and uses a fixture serialized with file version 2 to verify that the historic behavior is preserved. In the future we will likely need a slightly more complex mechanism that supports multiple substitutions with distinct version ranges; this just requires changing the map to be Symbol->Iterable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/36300
Differential Revision: D21058990
Pulled By: mruberry
fbshipit-source-id: 2b7c732878c0ecfcd9f0a6205fb6d6421feeaf61
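To make the remapping concrete, here is a minimal, hypothetical sketch of the version-gated lookup the summary describes. The names SymbolRange, kVersionedSymbols, and versioned_symbol_lookup are illustrative only and not the actual API in versioned_symbol.cpp, and the exact version-range semantics are an assumption: given a symbol and the file version a module was serialized with, the lookup returns either the original symbol or the upgrader builtin that preserves its historic behavior.

#include <cstdint>
#include <string>
#include <unordered_map>

// Illustrative only: maps a symbol to the replacement that should be used
// when loading a module whose serialized file version falls in
// [start_version, end_version].
struct SymbolRange {
  uint64_t start_version;
  uint64_t end_version;
  std::string replacement; // an "upgraders::" builtin with the old semantics
};

static const std::unordered_map<std::string, SymbolRange> kVersionedSymbols = {
    // Modules serialized at file version 2 or earlier keep the historic
    // behavior of the test-only symbol via its upgrader builtin.
    {"aten::_test_serialization_subcmul",
     {0, 2, "upgraders::_test_serialization_subcmul_0_2"}},
};

// Returns the symbol to compile against for a module serialized at `version`.
std::string versioned_symbol_lookup(const std::string& name, uint64_t version) {
  auto it = kVersionedSymbols.find(name);
  if (it != kVersionedSymbols.end() && version >= it->second.start_version &&
      version <= it->second.end_version) {
    return it->second.replacement;
  }
  return name; // current behavior: no substitution needed
}

Under this pattern, a BuiltinModule that is given a version would consult a map like this before handing a symbol back to the compiler.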
160 lines
5.1 KiB
C++
#include <torch/csrc/jit/frontend/builtin_functions.h>

#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/frontend/code_template.h>
#include <torch/csrc/jit/frontend/resolver.h>

namespace torch {
namespace jit {

auto scalar_operators_source = CodeTemplate(
    R"SCRIPT(
def mul(a : ${Scalar}, b : Tensor) -> Tensor:
  return b * a
def add(a : ${Scalar}, b : Tensor) -> Tensor:
  return b + a
def ne(a : ${Scalar}, b : Tensor) -> Tensor:
  return b != a
def eq(a : ${Scalar}, b : Tensor) -> Tensor:
  return b == a
def lt(a : ${Scalar}, b : Tensor) -> Tensor:
  return b > a
def le(a : ${Scalar}, b : Tensor) -> Tensor:
  return b >= a
def gt(a : ${Scalar}, b : Tensor) -> Tensor:
  return b < a
def ge(a : ${Scalar}, b : Tensor) -> Tensor:
  return b <= a
def sub(a : ${Scalar}, b : Tensor) -> Tensor:
  return torch.neg(b) + a
def div(a : ${Scalar}, b : Tensor) -> Tensor:
  return torch.reciprocal(b) * a
)SCRIPT");

auto _ntuple_ops = CodeTemplate(
    R"SCRIPT(
def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
  return x
)SCRIPT");

auto floordiv = CodeTemplate(
    R"SCRIPT(
def floordiv(self : Tensor, other : ${Rhs_Type}) -> Tensor:
  return torch.floor_divide(self, other)
)SCRIPT");

auto tensor_properties =
    R"SCRIPT(
def ndim(a : Tensor) -> int:
  return a.dim()
def T(a : Tensor) -> Tensor:
  return a.numpy_T()
def shape(a : Tensor) -> List[int]:
  return a.size()
)SCRIPT";

// This is only here for backwards-compatibility with the
// aten::_assert_int_or_pair op which was removed once we were able to compile
// torch.nn.functional.assert_int_or_pair
auto aten_ops =
    R"SCRIPT(
def _assert_int_or_pair(vals: List[int], name: str, message: str):
  pass
)SCRIPT";

// Implementations of historic symbol behaviors are defined here
// See note [Versioned Symbols]
auto _test_serialization_subcmul = R"SCRIPT(
def _test_serialization_subcmul_0_2(self: Tensor, other: Tensor, alpha: number=2) -> Tensor:
  return other - (self * alpha)
)SCRIPT";

struct BuiltinFunctionRegistry {
  const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
    const static std::vector<Function*> empty;
    // When initializing the builtin function library, we will re-enter
    // getAllBuiltinFunctionsFor since it is called in the compiler to
    // look up builtins, and initializing the builtin functions calls the
    // compiler. To avoid deadlocking, we use a recursive mutex (the same
    // thread can re-lock the mutex without waiting) and report no loaded
    // builtins during init.
    std::lock_guard<std::recursive_mutex> guard(mutex);
    if (state == INITIALIZING) {
      return empty;
    } else if (state == UNINITIALIZED) {
      state = INITIALIZING;
      loadBuiltinFunctions();
      state = INITIALIZED;
    }
    AT_ASSERT(state == INITIALIZED);
    auto it = builtins_by_name_.find(name);
    if (it == builtins_by_name_.end())
      return empty;
    return it->second;
  }

 private:
  void loadSource(const std::string& source, const std::string& the_namespace) {
    std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
    modules.emplace_back(cu);
    cu->define(c10::nullopt, source, nativeResolver(), /*self=*/nullptr);
    for (auto& method : cu->get_functions()) {
      builtins_by_name_[Symbol::fromQualString(
                            the_namespace + "::" + method->name())]
          .push_back(method);
    }
  }

  void loadBuiltinFunctions() {
    for (auto scalar : {"float", "int"}) {
      TemplateEnv env;
      env.s("Scalar", scalar);
      loadSource(scalar_operators_source.format(env), "aten");
    }

    using str_pair = std::pair<std::string, std::string>;
    const std::vector<str_pair> name_len = {
        str_pair("single", "1"),
        str_pair("pair", "2"),
        str_pair("triple", "3"),
        str_pair("quadruple", "4"),
    };
    for (const auto scalar : {"float", "int"}) {
      for (const auto& pair : name_len) {
        TemplateEnv env;
        env.s("Scalar", scalar);
        env.s("name", pair.first);
        env.s("Length", pair.second);
        loadSource(_ntuple_ops.format(env), "aten");
      }
    }
    for (auto rhs : {"number", "Tensor"}) {
      TemplateEnv env;
      env.s("Rhs_Type", rhs);
      loadSource(floordiv.format(env), "aten");
    }

    loadSource(aten_ops, "aten");

    // Loads functions implementing historic behavior, see note [Versioned
    // Symbols]
    // Note: these functions go into the "upgraders" namespace
    loadSource(_test_serialization_subcmul, "upgraders");

    // These are under `prim` instead of `aten` since they exist to bind certain
    // tensor property getters to corresponding methods
    loadSource(tensor_properties, "prim");
  }
  enum { UNINITIALIZED, INITIALIZING, INITIALIZED } state = UNINITIALIZED;
  std::recursive_mutex mutex;
  std::vector<std::shared_ptr<CompilationUnit>> modules;
  std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name_;
};
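
// Standalone illustration (not part of the registry above; names are
// hypothetical) of the re-entrant initialization pattern that
// getAllBuiltinFunctionsFor relies on: a std::recursive_mutex lets the same
// thread lock again while it already holds the lock, so a lookup made during
// initialization sees "nothing loaded" instead of deadlocking.
namespace {
struct ReentrantInitExample {
  int lookup() {
    std::lock_guard<std::recursive_mutex> guard(mutex_);
    if (state_ == INITIALIZING) {
      // Re-entrant call made while still initializing: report "nothing
      // loaded" rather than blocking on the lock this thread already holds.
      return 0;
    }
    if (state_ == UNINITIALIZED) {
      state_ = INITIALIZING;
      value_ = lookup() + 1; // initialization re-enters lookup()
      state_ = INITIALIZED;
    }
    return value_;
  }

 private:
  enum { UNINITIALIZED, INITIALIZING, INITIALIZED } state_ = UNINITIALIZED;
  std::recursive_mutex mutex_;
  int value_ = 0;
};
} // namespace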

const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
  static BuiltinFunctionRegistry registry;
  return registry.getAllBuiltinFunctionsFor(name);
}
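// Example (illustrative): the compiler resolves a builtin such as
// aten::floordiv by asking
//
//   getAllBuiltinFunctionsFor(Symbol::fromQualString("aten::floordiv"))
//
// which returns the Function* overloads compiled above (one per ${Rhs_Type}
// instantiation of the floordiv template), or the empty vector when a symbol
// has no builtin implementation.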

} // namespace jit
} // namespace torch