Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42413

When a default argument is added to an operator, it does not break backward compatibility (BC) for full JIT, but it does break BC for mobile bytecode (see https://github.com/pytorch/pytorch/pull/40737 for an example). To keep the bytecode backward compatible in this case, we:

1. Introduce kMinSupportedBytecodeVersion. The loaded model version must be between kMinSupportedBytecodeVersion and kProducedBytecodeVersion.
2. When an operator is updated and we can handle BC, bump kProducedBytecodeVersion (for example, from 3 to 4).
3. If the model was produced against the older version of the operator, install an adapter function at load time. For an added default argument, the adapter pushes the default value onto the stack before calling the actual operator function.

Test Plan: Imported from OSS

Reviewed By: xcheng16

Differential Revision: D22898314

Pulled By: iseeyuan

fbshipit-source-id: 90d339f8e1365f4bb178db8db7c147390173372b
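The adapter in step 3 lives in Function::append_operator in the file below. As a rough illustration of how a bytecode loader might drive it, here is a minimal sketch; the helper name, loop shape, and operator-name container are hypothetical, and only append_operator's signature is taken from the file.

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include <torch/csrc/jit/mobile/function.h>

namespace {

// Hypothetical helper: register every operator a function's bytecode refers
// to, passing the model's bytecode version so version-gated adapters can be
// applied by append_operator.
bool register_ops_for_function(
    torch::jit::mobile::Function& fn,
    const std::vector<std::pair<std::string, std::string>>& op_names,
    int64_t model_version) {
  for (const auto& op : op_names) {
    // For a bytecode-version-3 model, append_operator wraps
    // "aten::_convolution" so the missing trailing default bool is pushed
    // onto the stack before the operator is called.
    if (!fn.append_operator(op.first, op.second, model_version)) {
      return false; // not found in the JIT registry or the c10 dispatcher
    }
  }
  return true;
}

} // anonymous namespace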
110 lines | 3.1 KiB | C++
#include <torch/csrc/jit/mobile/function.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/custom_class_detail.h>

namespace torch {
namespace jit {

char const* toString(OpCode op);

namespace mobile {

Function::Function(c10::QualifiedName name)
    : name_(name), code_(std::make_shared<Code>()) {}

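// Append one bytecode instruction; X and N are opcode-dependent operands
// (their meaning is defined by the instruction set in runtime/instruction.h).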
void Function::append_instruction(OpCode op, int X, int N) {
  TORCH_CHECK(
      op != CREATE_OBJECT,
      "CREATE_OBJECT is not supported in mobile module. ",
      "Workaround: instead of using arbitrary class type (class Foo()), ",
      "define a pytorch class (class Foo(torch.nn.Module)).");
  TORCH_CHECK(
      isOpSupportedInMobile(op),
      toString(op),
      " is not supported in mobile module.");
  code_->instructions_.emplace_back(op, X, N);
}

bool Function::append_operator(
    const std::string& name,
    const std::string& overload_name,
    int64_t model_version) {
  // Keep the original opname in code_
  code_->op_names_.emplace_back(name, overload_name);
  auto opname = code_->op_names_.back();

  auto opname_c10 = opname;
  std::function<void(Stack&)> fn;

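  // Resolve the operator: look in the JIT operator registry first, and fall
  // back to the c10 dispatcher if no JIT operator is registered for this name.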
  auto jit_op = findOperatorFor(opname);
  if (jit_op) {
    fn = [jit_op](Stack& stack) { jit_op->getOperation()(&stack); };
  } else {
    auto op = c10::Dispatcher::singleton().findSchema(opname_c10);
    if (op.has_value()) {
      fn = [op](Stack& stack) { op->callBoxed(&stack); };
    } else {
      return false;
    }
  }

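  // Loading-time adapter (step 3 of the versioning scheme in the commit
  // message): only models written at bytecode version 3, i.e. older than the
  // current kProducedBytecodeVersion, need the missing default argument
  // patched in.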
  if (model_version == 0x3L &&
      model_version < caffe2::serialize::kProducedBytecodeVersion &&
      opname == c10::OperatorName("aten::_convolution", "")) {
    // A default-value argument will be added in
    // https://github.com/pytorch/pytorch/pull/40737. This wrapper is used to
    // handle backward compatibility, where there is no default bool value in
    // old models.
    fn = [fn](Stack& stack) {
      stack.push_back(true);
      fn(stack);
    };
  }

  code_->operators_.emplace_back(fn);
  return true;
}

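// Module debug info maps each program counter to the module hierarchy that
// produced the instruction; entries default to "<no module info>" until
// set_module_info() fills them in.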
void Function::set_module_debug_info_list_size(size_t size) {
  pc_to_module_debug_info_.resize(size);
  for (size_t i = 0; i < size; ++i) {
    pc_to_module_debug_info_[i] = "<no module info>";
  }
}

void Function::set_module_info(const std::string& module_info, size_t pc) {
  TORCH_CHECK(
      pc < pc_to_module_debug_info_.size(),
      "Module debug info index out of boundary.");
  pc_to_module_debug_info_[pc] = module_info;
}

void Function::append_constant(const c10::IValue& constant) {
  code_->constants_.push_back(constant);
}

void Function::append_type(const at::TypePtr& type) {
  code_->types_.push_back(type);
}

void Function::set_register_size(size_t size) {
  code_->register_size_ = size;
}

std::string Function::get_module_debug_info(size_t pc) const {
  TORCH_CHECK(
      pc < pc_to_module_debug_info_.size(),
      "Module debug info index out of boundary.");
  return pc_to_module_debug_info_[pc];
}

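// Run this function's bytecode on the mobile interpreter, using `stack` to
// pass inputs and receive outputs.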
bool Function::run(Stack& stack) const {
  InterpreterState interp_state(code_);
  return interp_state.run(stack);
}

} // namespace mobile
} // namespace jit
} // namespace torch