mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
interpreter handling for varargs to remove need for looking at Node (#32791)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32791 When a registered operator has varags (ends with ... in its schema), the interpreter now appends the number of arguments to the top of the stack before invoking the operator. This allows the removal of more uses of Node* in the interpreter. This PR also then cleans up the constructors for Operator to make it more likely someone chooses the correct one. After making these ops: ``` USES NODE: prim::TupleUnpack(...) -> (...) USES NODE: prim::TupleSlice(...) -> (...) USES NODE: prim::TupleConstruct(...) -> (...) USES NODE: prim::ListUnpack(...) -> (...) USES NODE: prim::ListConstruct(...) -> (...) USES NODE: prim::DictConstruct(...) -> (...) USES NODE: prim::Constant() -> (...) USES NODE: prim::isinstance(...) -> (...) USES NODE: prim::CreateObject(...) -> (...) USES NODE: prim::fork(...) -> (...) USES NODE: aten::warn(str message, *, int stacklevel=2) -> () # need stack level information, so ideally in interpreter so it can look at the stack ``` Into interpreter primitives, we can remove all but two constructors for operators: one that is (schema_string, operation), and one that is (symbol, op_creator) for the remaining weird primitives. Test Plan: Imported from OSS Differential Revision: D19673158 Pulled By: zdevito fbshipit-source-id: 95442a001538a6f53c1db4a210f8557ef118de66
This commit is contained in:
parent
da015c77a1
commit
c59e35b147
|
|
@ -29,6 +29,26 @@ white_list = [
|
|||
('prim::ImplicitTensorToNum', datetime.date(2020, 3, 1)),
|
||||
('aten::is_owner', datetime.date(2020, 3, 1)),
|
||||
('aten::to_here', datetime.date(2020, 3, 1)),
|
||||
('prim::isinstance', datetime.date(2020, 3, 1)),
|
||||
('prim::CreateObject', datetime.date(2020, 3, 1)),
|
||||
('prim::Uninitialized', datetime.date(2020, 3, 1)),
|
||||
('prim::fork', datetime.date(2020, 3, 1)),
|
||||
('prim::unchecked_cast', datetime.date(2020, 3, 1)),
|
||||
('prim::DictConstruct', datetime.date(2020, 3, 1)),
|
||||
('prim::ListConstruct', datetime.date(2020, 3, 1)),
|
||||
('prim::ListUnpack', datetime.date(2020, 3, 1)),
|
||||
('prim::TupleConstruct', datetime.date(2020, 3, 1)),
|
||||
('prim::TupleIndex', datetime.date(2020, 3, 1)),
|
||||
('prim::TupleSlice', datetime.date(2020, 3, 1)),
|
||||
('prim::TupleUnpack', datetime.date(2020, 3, 1)),
|
||||
('prim::AutogradAdd', datetime.date(2020, 3, 1)),
|
||||
('prim::AutogradAnyNonZero', datetime.date(2020, 3, 1)),
|
||||
('onnx::Shape', datetime.date(2020, 3, 1)),
|
||||
('onnx::Reshape', datetime.date(2020, 3, 1)),
|
||||
('prim::BroadcastSizes', datetime.date(2020, 3, 1)),
|
||||
('prim::Print', datetime.date(2020, 3, 1)),
|
||||
('prim::MMTreeReduce', datetime.date(2020, 3, 1)),
|
||||
('prim::Constant', datetime.date(2020, 3, 1)),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -90,10 +110,12 @@ if __name__ == '__main__':
|
|||
line = f.readline()
|
||||
if not line:
|
||||
break
|
||||
if "torch.classes" in line or "RRef" in line:
|
||||
if "torch.classes" in line or "RRef" in line or "Any" in line:
|
||||
# TODO Fix type __torch__.torch.classes.xxx
|
||||
# TODO Delete RRef special case after add the RRef type
|
||||
# TODO: wait until nightly knows how to parse Any
|
||||
continue
|
||||
|
||||
s = parse_schema(line.strip())
|
||||
slist = new_schema_dict.get(s.name, [])
|
||||
slist.append(s)
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ graph():
|
|||
%y : Tensor = aten::tensor(%3, %10, %7, %15)
|
||||
%9 : int[] = prim::ListConstruct(%1, %2)
|
||||
%z : Tensor = aten::tensor(%9, %10, %7, %15)
|
||||
%f = prim::Print(%x, %y, %z)
|
||||
prim::Print(%x, %y, %z)
|
||||
return (%1)
|
||||
)IR",
|
||||
&*graph);
|
||||
|
|
|
|||
|
|
@ -142,13 +142,7 @@ c10::optional<Value*> tryInsertConstant(
|
|||
|
||||
RegisterOperators reg({
|
||||
Operator(
|
||||
FunctionSchema(
|
||||
prim::Constant,
|
||||
"",
|
||||
{},
|
||||
{},
|
||||
/*is_vararg=*/false,
|
||||
/*is_varret=*/true),
|
||||
[](const Node* node) -> Operation {
|
||||
TypePtr type = node->output()->type();
|
||||
if (type->isSubtypeOf(TensorType::get())) {
|
||||
|
|
|
|||
|
|
@ -522,8 +522,13 @@ struct CodeImpl {
|
|||
|
||||
void emitOperator(Node* node) {
|
||||
emitLoadInputs(node->inputs());
|
||||
const Operator& op = node->getOperator();
|
||||
if (op.hasOperation() && op.schema().is_vararg()) {
|
||||
insertInstruction(OPN, operator_table_.size(), node->inputs().size());
|
||||
} else {
|
||||
insertInstruction(OP, operator_table_.size());
|
||||
operator_table_.emplace_back(node->getOperation());
|
||||
}
|
||||
operator_table_.emplace_back(op.getOperation(node));
|
||||
}
|
||||
|
||||
void emitWait(Node* node) {
|
||||
|
|
@ -757,7 +762,7 @@ struct CodeImpl {
|
|||
|
||||
void dump(std::ostream& out, size_t i) const {
|
||||
out << i << " " << instructions_[i];
|
||||
if (instructions_[i].op == OP || instructions_[i].op == CALL) {
|
||||
if (instructions_[i].op == OP || instructions_[i].op == CALL || instructions_[i].op == OPN) {
|
||||
out << " # " << *instructions_source_[i];
|
||||
} else {
|
||||
out << "\n";
|
||||
|
|
@ -890,7 +895,9 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
|||
++af.pc;
|
||||
break;
|
||||
case OPN:
|
||||
AT_ERROR("OPN is currently supported in mobile mode only.");
|
||||
stack.push_back(inst.N);
|
||||
af.operators[inst.X](stack);
|
||||
++af.pc;
|
||||
break;
|
||||
case LOAD:
|
||||
stack.emplace_back(reg(inst.X));
|
||||
|
|
|
|||
|
|
@ -69,14 +69,16 @@ struct TORCH_API Operator {
|
|||
c10Handle_(opHandle),
|
||||
options_(c10Handle_->options()) {}
|
||||
|
||||
|
||||
Operator(
|
||||
FunctionSchema schema,
|
||||
OperationCreator op_creator,
|
||||
const std::string& schema,
|
||||
int(*op)(Stack&),
|
||||
c10::OperatorOptions options = c10::OperatorOptions())
|
||||
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
|
||||
op_creator_(std::move(op_creator)),
|
||||
: schema_string_(schema),
|
||||
op_(std::make_shared<Operation>(std::move(op))),
|
||||
options_(std::move(options)) {}
|
||||
|
||||
|
||||
Operator(
|
||||
const std::string& schema,
|
||||
OperationCreator op_creator,
|
||||
|
|
@ -88,40 +90,13 @@ struct TORCH_API Operator {
|
|||
// Helper constructor to register `op` to run
|
||||
// run for _every_ IR Node where n.kind() == name, regardless of arguments.
|
||||
// This is accomplished by marking the schema varargs and having no required
|
||||
// arguments. This is used for things like prim::While or prim::If that can
|
||||
// take a number of different valid input types and lengths.
|
||||
// arguments.
|
||||
Operator(
|
||||
Symbol name,
|
||||
OperationCreator op_creator,
|
||||
c10::OperatorOptions options = c10::OperatorOptions())
|
||||
: Operator(
|
||||
varArgSchemaWithName(name),
|
||||
std::move(op_creator),
|
||||
std::move(options)) {}
|
||||
|
||||
Operator(
|
||||
Symbol name,
|
||||
Operation op,
|
||||
c10::OperatorOptions options = c10::OperatorOptions())
|
||||
: Operator(
|
||||
varArgSchemaWithName(name),
|
||||
std::move(op),
|
||||
std::move(options)) {}
|
||||
|
||||
Operator(
|
||||
FunctionSchema schema,
|
||||
Operation op,
|
||||
c10::OperatorOptions options = c10::OperatorOptions())
|
||||
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
|
||||
op_(std::make_shared<Operation>(std::move(op))),
|
||||
options_(std::move(options)) {}
|
||||
|
||||
Operator(
|
||||
const std::string& schema,
|
||||
int(*op)(Stack&),
|
||||
c10::OperatorOptions options = c10::OperatorOptions())
|
||||
: schema_string_(schema),
|
||||
op_(std::make_shared<Operation>(std::move(op))),
|
||||
: schema_(std::make_shared<FunctionSchema>(varArgSchemaWithName(name))),
|
||||
op_creator_(std::move(op_creator)),
|
||||
options_(std::move(options)) {}
|
||||
|
||||
Operation getOperation(const Node* node = nullptr) const {
|
||||
|
|
@ -159,7 +134,9 @@ struct TORCH_API Operator {
|
|||
}
|
||||
return options_.aliasAnalysis();
|
||||
}
|
||||
|
||||
bool hasOperation() const {
|
||||
return op_ != nullptr;
|
||||
}
|
||||
private:
|
||||
static FunctionSchema varArgSchemaWithName(Symbol name) {
|
||||
return FunctionSchema(
|
||||
|
|
|
|||
|
|
@ -109,10 +109,9 @@ bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {
|
|||
}
|
||||
|
||||
RegisterOperators mm_tree_reduction_reg({Operator(
|
||||
prim::MMTreeReduce,
|
||||
[](const Node* node) -> Operation {
|
||||
size_t num_inputs = node->inputs().size();
|
||||
return [num_inputs](Stack& stack) {
|
||||
"prim::MMTreeReduce(...) -> Tensor",
|
||||
[](Stack& stack) {
|
||||
auto num_inputs = pop(stack).toInt();
|
||||
std::vector<at::Tensor> inputs;
|
||||
inputs.reserve(num_inputs);
|
||||
for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
|
||||
|
|
@ -160,7 +159,6 @@ RegisterOperators mm_tree_reduction_reg({Operator(
|
|||
push(stack, std::move(acc));
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
},
|
||||
aliasAnalysisIsSpecialCase())});
|
||||
|
||||
|
|
|
|||
|
|
@ -755,10 +755,9 @@ RegisterOperators reg(
|
|||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
prim::Print,
|
||||
[](const Node* node) -> Operation {
|
||||
size_t num_inputs = node->inputs().size();
|
||||
return [num_inputs](Stack& stack) {
|
||||
"prim::Print(...) -> ()",
|
||||
[](Stack& stack) {
|
||||
auto num_inputs = pop(stack).toInt();
|
||||
std::stringstream ss;
|
||||
bool first = true;
|
||||
for (const IValue& i : last(stack, num_inputs)) {
|
||||
|
|
@ -773,24 +772,21 @@ RegisterOperators reg(
|
|||
TORCH_INTERNAL_ASSERT(handler);
|
||||
handler(ss.str());
|
||||
return 0;
|
||||
};
|
||||
},
|
||||
aliasAnalysisSpecialCase()),
|
||||
Operator(
|
||||
prim::BroadcastSizes,
|
||||
[](const Node* node) -> Operation {
|
||||
size_t num_inputs = node->inputs().size();
|
||||
return [num_inputs](Stack& stack) {
|
||||
"prim::BroadcastSizes(...) -> int[]",
|
||||
[](Stack& stack) {
|
||||
auto num_inputs = pop(stack).toInt();
|
||||
std::vector<int64_t> size;
|
||||
size.reserve(8);
|
||||
for (size_t i = 0; i < num_inputs; ++i) {
|
||||
for (auto i = 0; i < num_inputs; ++i) {
|
||||
size = at::infer_size(
|
||||
size, peek(stack, i, num_inputs).toIntVector());
|
||||
}
|
||||
drop(stack, num_inputs);
|
||||
push(stack, IValue(std::move(size)));
|
||||
return 0;
|
||||
};
|
||||
},
|
||||
aliasAnalysisSpecialCase()),
|
||||
Operator(
|
||||
|
|
@ -824,12 +820,7 @@ RegisterOperators reg(
|
|||
},
|
||||
aliasAnalysisSpecialCase()),
|
||||
Operator(
|
||||
FunctionSchema(
|
||||
"aten::warn",
|
||||
"",
|
||||
{Argument("message", StringType::get()),
|
||||
Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)},
|
||||
{}),
|
||||
"aten::warn(str message, int stacklevel=2) -> ()",
|
||||
[](const Node* node) -> Operation {
|
||||
auto range = node->sourceRange().source();
|
||||
if (range->filename()) {
|
||||
|
|
@ -873,7 +864,7 @@ RegisterOperators reg(
|
|||
aliasAnalysisFromSchema()),
|
||||
|
||||
Operator(
|
||||
c10::onnx::Reshape,
|
||||
"onnx::Reshape(Tensor input, Tensor shape) -> Tensor",
|
||||
[](Stack& stack) {
|
||||
at::Tensor input, shape;
|
||||
pop(stack, input, shape);
|
||||
|
|
@ -885,7 +876,7 @@ RegisterOperators reg(
|
|||
},
|
||||
aliasAnalysisSpecialCase()),
|
||||
Operator(
|
||||
c10::onnx::Shape,
|
||||
"onnx::Shape(Tensor t) -> Tensor",
|
||||
[](Stack& stack) {
|
||||
auto t = pop(stack).toTensor();
|
||||
at::IntArrayRef sizes = t.sizes();
|
||||
|
|
@ -901,9 +892,8 @@ RegisterOperators reg(
|
|||
aliasAnalysisSpecialCase()),
|
||||
Operator(
|
||||
"prim::AutogradAnyNonZero(...) -> bool",
|
||||
[](const Node* node) -> Operation {
|
||||
size_t num_inputs = node->inputs().size();
|
||||
return [num_inputs](Stack& stack) {
|
||||
[](Stack& stack) {
|
||||
auto num_inputs = pop(stack).toInt();
|
||||
bool result = false;
|
||||
for (const IValue& v : last(stack, num_inputs)) {
|
||||
if (v.isTensor()) {
|
||||
|
|
@ -927,11 +917,10 @@ RegisterOperators reg(
|
|||
drop(stack, num_inputs);
|
||||
stack.emplace_back(result);
|
||||
return 0;
|
||||
};
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
prim::AutogradAdd,
|
||||
"prim::AutogradAdd(Any a, Any b) -> Any",
|
||||
[](Stack& stack) {
|
||||
at::Tensor a, b;
|
||||
pop(stack, a, b);
|
||||
|
|
@ -1008,7 +997,8 @@ RegisterOperators reg(
|
|||
},
|
||||
aliasAnalysisSpecialCase()),
|
||||
Operator(
|
||||
prim::TupleIndex,
|
||||
// note the compiler knows to type TupleIndex more accurately than it is listed here.
|
||||
"prim::TupleIndex(Any tup, int i) -> Any",
|
||||
[](Stack& stack) {
|
||||
int64_t index = pop(stack).toInt();
|
||||
auto tuple = pop(stack).toTuple();
|
||||
|
|
@ -1225,7 +1215,7 @@ RegisterOperators reg(
|
|||
"prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)",
|
||||
noop,
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(prim::unchecked_cast, noop, aliasAnalysisSpecialCase()),
|
||||
Operator("prim::unchecked_cast(t x) -> t", noop, aliasAnalysisSpecialCase()),
|
||||
Operator(
|
||||
prim::fork,
|
||||
[](const Node* node) -> Operation {
|
||||
|
|
@ -1258,7 +1248,7 @@ RegisterOperators reg(
|
|||
},
|
||||
aliasAnalysisSpecialCase()),
|
||||
Operator(
|
||||
prim::Uninitialized,
|
||||
"prim::Uninitialized() -> Any",
|
||||
[](Stack& stack) {
|
||||
push(stack, IValue::uninitialized());
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -264,12 +264,10 @@ RegisterOperators reg({
|
|||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
"aten::format(str self, ...) -> str",
|
||||
[](const Node* node) -> Operation {
|
||||
size_t num_inputs = node->inputs().size();
|
||||
return [num_inputs](Stack& stack) {
|
||||
[](Stack& stack) {
|
||||
size_t num_inputs = pop(stack).toInt();
|
||||
formatFunc(num_inputs, stack);
|
||||
return 0;
|
||||
};
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
|
|||
{"bool", BoolType::get()},
|
||||
{"None", NoneType::get()},
|
||||
{"Capsule", CapsuleType::get()},
|
||||
{"Any", at::AnyType::get()},
|
||||
};
|
||||
auto tok = L.cur();
|
||||
if (!L.nextIf(TK_NONE)) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user