Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56845

Handle the forward/backward compatibility issues caused by default arguments added to operators in mobile. As an example, in the older version, operator aten::foo's schema is
```
foo(Tensor a, Tensor b) -> Tensor
```
In the new version, the schema is updated to
```
foo(Tensor a, Tensor b, int groups=1) -> Tensor
```

## Model file
Serialize the number of specified arguments for each operator into the bytecode operator table. Before, the operator table contained only the operator name and overload name:
```
('operators', (('aten::foo', ''),))
```
Now the number of specified arguments is added:
```
# bytecode version 6
('operators', (('aten::foo', '', 2),))
```
where "2" is the number of specified arguments. Since this is a bytecode schema change, the bytecode version number is bumped. This PR is to be landed after #56002, where the version number is bumped from 4 to 5; this PR bumps it from 5 to 6.

## Runtime and backward compatibility
When the operator is found (either jit or c10), we have the OperatorHandle, and the operator schema can be accessed through
```
op.value().schema().arguments()
```
Adaptation is implemented to handle backward compatibility. For the example above, the new runtime holds the updated schema:
```
foo(Tensor a, Tensor b, int groups=1) -> Tensor
```
whereas the model file carries
```
(('aten::foo', ''), 2)
```
We can implement a wrapper around the original function pointer that pushes the default arguments onto the stack.

## Delivery time and forward compatibility
At model delivery time, two checks can be done:

### Operator check
Two APIs are to be provided:
* Runtime: an API to get a runtime's ops and their schemas (i.e. the number of args). D27920185 (WIP)
* Model: an API to get a model's ops and their schema requirements (i.e. the number of args required).

The APIs can be used to check that
* runtime.ops() is a superset of model.ops()
* for each op in model.ops(), its schema is compatible with the one in runtime.ops(), i.e. the number of args required by the model op is <= the number of args in the runtime op.

A minimal sketch of this check is given below, after this summary. Note that only root ops in the model need to be checked here; transient ops do not. For example, if a root op "aten::root" calls "aten::foo", it is "aten::root"'s responsibility to adapt to "aten::foo"'s change, or "aten::root" itself needs to be updated too.

### Bytecode version backport (PR coming)
When delivering a model with bytecode v6 to a runtime that only works with bytecode v5 and lower, a backport is needed:
* The number of arguments is removed from the operator table.
* The bytecode version is changed from 6 to 5.

Note that this backport is a pure format change; it does not guarantee that the backported model always runs in the old runtime. The operator check mentioned above should be done first, before the model is backported to v5.

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D27986544

Pulled By: iseeyuan

fbshipit-source-id: 143e19d4798cfb96b65095538dd648eead4e3fda
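The delivery-time operator check described above is not implemented in this file (the runtime-side ops API is still WIP in D27920185). The following is only a minimal sketch of what such a check could look like; the `RuntimeOpSchema` and `ModelOpRequirement` descriptors and the `isModelCompatibleWithRuntime` helper are hypothetical names for illustration, not real PyTorch APIs.
```
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical descriptor for an operator schema known to the runtime.
struct RuntimeOpSchema {
  int num_args; // total number of arguments in the runtime schema
};

// Hypothetical requirement extracted from a model's bytecode operator table,
// e.g. ('aten::foo', '', 2) -> {"aten::foo", 2}.
struct ModelOpRequirement {
  std::string name; // qualified op name, e.g. "aten::foo"
  int num_specified_args; // number of arguments serialized for this op
};

// Returns true if every root op the model needs exists in the runtime and
// does not require more arguments than the runtime schema provides.
bool isModelCompatibleWithRuntime(
    const std::vector<ModelOpRequirement>& model_root_ops,
    const std::unordered_map<std::string, RuntimeOpSchema>& runtime_ops) {
  for (const auto& op : model_root_ops) {
    auto it = runtime_ops.find(op.name);
    if (it == runtime_ops.end()) {
      return false; // runtime.ops() must be a superset of model.ops()
    }
    if (op.num_specified_args > it->second.num_args) {
      return false; // model needs more args than the runtime schema has
    }
  }
  return true;
}
```
In practice, the runtime map would be produced by the runtime ops API and the model list by reading the bytecode operator table shown above; only root ops need to be listed.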
156 lines
4.8 KiB
C++
#include <torch/csrc/jit/mobile/function.h>

#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/custom_class_detail.h>

namespace torch {
namespace jit {

char const* toString(OpCode op);

namespace mobile {
Function::Function(c10::QualifiedName name)
    : name_(std::move(name)), code_(std::make_shared<Code>()) {}

const c10::QualifiedName& Function::qualname() const {
  return name_;
}

const std::string& Function::name() const {
  return name_.name();
}

void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) {
  TORCH_CHECK(
      isOpSupportedInMobile(op),
      toString(op),
      " is not supported in mobile module.");
  code_->instructions_.emplace_back(op, X, N);
  code_->debug_handles_.emplace_back(dbg_handle);
}

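// Adds an operator to this function's operator table. The operator is looked
// up in the JIT operator registry first, falling back to the c10 dispatcher;
// if it cannot be found at all, false is returned. When the model was
// serialized with fewer arguments than the current schema declares
// (num_specified_args < args.size()), the missing trailing arguments must all
// have default values, and a wrapper pushes those defaults onto the stack
// before calling the operator.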
bool Function::append_operator(
    const std::string& name,
    const std::string& overload_name,
    const c10::optional<int>& num_specified_args,
    int64_t model_version /* TODO: T90339189 deprecate all v3 when v3 models are removed */) {
  // Keep the original opname in code_
  code_->op_names_.emplace_back(name, overload_name);
  auto opname = code_->op_names_.back();

  const auto& opname_c10 = opname;
  std::function<void(Stack&)> fn;

  auto jit_op = findOperatorFor(opname);
  std::vector<c10::Argument> args;
  if (jit_op) {
    fn = [jit_op](Stack& stack) { jit_op->getOperation()(&stack); };
    args = jit_op->schema().arguments();
  } else {
    auto op = c10::Dispatcher::singleton().findSchema(opname_c10);
    if (op.has_value()) {
      fn = [op](Stack& stack) { op->callBoxed(&stack); };
      if (op->hasSchema()) {
        args = op->schema().arguments();
      } else {
        TORCH_CHECK(false, "arguments are missing for operator ", opname);
      }
    } else {
      return false;
    }
  }

  if (model_version == 0x3LL &&
      opname == c10::OperatorName("aten::_convolution", "")) {
    // Since byte-code version 0x4L, convolution has an additional
    // default-value argument (allow_tf32=True, see
    // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
    // backward compatibility with models of byte-code version <= 0x3L, where
    // this bool argument does not yet exist.
    fn = [fn](Stack& stack) {
      stack.push_back(true);
      fn(stack);
    };
  } else {
    // num_specified_args >= 0 indicates that the number of arguments is
    // available from the model. We can use it to handle backward
    // compatibility.
    if (num_specified_args &&
        num_specified_args.value() < static_cast<int64_t>(args.size())) {
      // Sanity check at load time, to save perf at runtime.
      for (size_t i = num_specified_args.value(); i < args.size(); ++i) {
        auto default_val = args[i].default_value();
        TORCH_CHECK(
            default_val.has_value(),
            "Error happened when preparing default values for the arguments. The ",
            i,
            "th argument of operator ",
            opname,
            " does not have a specified value or default value.");
      }
      fn = [fn, num_specified_args, args](Stack& stack) {
        // Push the missing trailing default values before calling the op.
        for (size_t i = num_specified_args.value(); i < args.size(); ++i) {
          stack.push_back(args[i].default_value());
        }
        fn(stack);
      };
    }
  }
  code_->operators_.emplace_back(fn);
  return true;
}

void Function::append_constant(const c10::IValue& constant) {
  code_->constants_.push_back(constant);
}

void Function::append_type(const at::TypePtr& type) {
  code_->types_.push_back(type);
}

void Function::set_register_size(size_t size) {
  code_->register_size_ = size;
}

int64_t Function::get_debug_handle(size_t pc) const {
  TORCH_CHECK(code_, "Valid code must exist.");
  TORCH_CHECK(
      pc < code_->debug_handles_.size(),
      "Module debug info index out of boundary.");
  return code_->debug_handles_[pc];
}

void Function::setSchema(c10::FunctionSchema schema) {
  schema_ = std::move(schema);
}

const at::optional<c10::FunctionSchema>& Function::getSchema() const {
  return schema_;
}

bool Function::run(Stack& stack) const {
  const auto& schema = getSchema();
  if (schema) { // if we have a schema then resolve optional args if any
    schema->checkAndNormalizeInputs(
        stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
  }
  InterpreterState interp_state(code_);
  return interp_state.run(stack);
}

c10::IValue Function::operator()(Stack& stack) const {
  run(stack);
  return stack.front();
}

const std::shared_ptr<Code> Function::get_code() const {
  return code_;
}

int64_t Function::getExceptionDebugHandle() const {
  size_t pc = getInterpretersExceptionPC();
  return (pc < code_->debug_handles_.size()) ? code_->debug_handles_[pc] : -1;
}

} // namespace mobile
} // namespace jit
} // namespace torch