interpreter handling for varargs to remove need for looking at Node (#32791)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32791

When a registered operator has varags (ends with ... in its schema),
the interpreter now appends the number of arguments to the top of
the stack before invoking the operator. This allows the removal of more
uses of Node* in the interpreter.

This PR also then cleans up the constructors for Operator to make
it more likely someone chooses the correct one. After making these ops:

```
USES NODE: prim::TupleUnpack(...) -> (...)
USES NODE: prim::TupleSlice(...) -> (...)
USES NODE: prim::TupleConstruct(...) -> (...)
USES NODE: prim::ListUnpack(...) -> (...)
USES NODE: prim::ListConstruct(...) -> (...)
USES NODE: prim::DictConstruct(...) -> (...)
USES NODE: prim::Constant() -> (...)
USES NODE: prim::isinstance(...) -> (...)
USES NODE: prim::CreateObject(...) -> (...)
USES NODE: prim::fork(...) -> (...)
USES NODE: aten::warn(str message, *, int stacklevel=2) -> () # need stack level information, so ideally in interpreter so it can look at the stack
```

Into interpreter primitives, we can remove all but two constructors for operators:
one that is (schema_string, operation), and one that is (symbol, op_creator) for
the remaining weird primitives.

Test Plan: Imported from OSS

Differential Revision: D19673158

Pulled By: zdevito

fbshipit-source-id: 95442a001538a6f53c1db4a210f8557ef118de66
This commit is contained in:
Zachary DeVito 2020-02-18 15:02:36 -08:00 committed by Facebook Github Bot
parent da015c77a1
commit c59e35b147
9 changed files with 133 additions and 146 deletions

View File

@ -29,6 +29,26 @@ white_list = [
('prim::ImplicitTensorToNum', datetime.date(2020, 3, 1)),
('aten::is_owner', datetime.date(2020, 3, 1)),
('aten::to_here', datetime.date(2020, 3, 1)),
('prim::isinstance', datetime.date(2020, 3, 1)),
('prim::CreateObject', datetime.date(2020, 3, 1)),
('prim::Uninitialized', datetime.date(2020, 3, 1)),
('prim::fork', datetime.date(2020, 3, 1)),
('prim::unchecked_cast', datetime.date(2020, 3, 1)),
('prim::DictConstruct', datetime.date(2020, 3, 1)),
('prim::ListConstruct', datetime.date(2020, 3, 1)),
('prim::ListUnpack', datetime.date(2020, 3, 1)),
('prim::TupleConstruct', datetime.date(2020, 3, 1)),
('prim::TupleIndex', datetime.date(2020, 3, 1)),
('prim::TupleSlice', datetime.date(2020, 3, 1)),
('prim::TupleUnpack', datetime.date(2020, 3, 1)),
('prim::AutogradAdd', datetime.date(2020, 3, 1)),
('prim::AutogradAnyNonZero', datetime.date(2020, 3, 1)),
('onnx::Shape', datetime.date(2020, 3, 1)),
('onnx::Reshape', datetime.date(2020, 3, 1)),
('prim::BroadcastSizes', datetime.date(2020, 3, 1)),
('prim::Print', datetime.date(2020, 3, 1)),
('prim::MMTreeReduce', datetime.date(2020, 3, 1)),
('prim::Constant', datetime.date(2020, 3, 1)),
]
@ -90,10 +110,12 @@ if __name__ == '__main__':
line = f.readline()
if not line:
break
if "torch.classes" in line or "RRef" in line:
if "torch.classes" in line or "RRef" in line or "Any" in line:
# TODO Fix type __torch__.torch.classes.xxx
# TODO Delete RRef special case after add the RRef type
# TODO: wait until nightly knows how to parse Any
continue
s = parse_schema(line.strip())
slist = new_schema_dict.get(s.name, [])
slist.append(s)

View File

@ -67,7 +67,7 @@ graph():
%y : Tensor = aten::tensor(%3, %10, %7, %15)
%9 : int[] = prim::ListConstruct(%1, %2)
%z : Tensor = aten::tensor(%9, %10, %7, %15)
%f = prim::Print(%x, %y, %z)
prim::Print(%x, %y, %z)
return (%1)
)IR",
&*graph);

View File

@ -142,13 +142,7 @@ c10::optional<Value*> tryInsertConstant(
RegisterOperators reg({
Operator(
FunctionSchema(
prim::Constant,
"",
{},
{},
/*is_vararg=*/false,
/*is_varret=*/true),
[](const Node* node) -> Operation {
TypePtr type = node->output()->type();
if (type->isSubtypeOf(TensorType::get())) {

View File

@ -522,8 +522,13 @@ struct CodeImpl {
void emitOperator(Node* node) {
emitLoadInputs(node->inputs());
const Operator& op = node->getOperator();
if (op.hasOperation() && op.schema().is_vararg()) {
insertInstruction(OPN, operator_table_.size(), node->inputs().size());
} else {
insertInstruction(OP, operator_table_.size());
operator_table_.emplace_back(node->getOperation());
}
operator_table_.emplace_back(op.getOperation(node));
}
void emitWait(Node* node) {
@ -757,7 +762,7 @@ struct CodeImpl {
void dump(std::ostream& out, size_t i) const {
out << i << " " << instructions_[i];
if (instructions_[i].op == OP || instructions_[i].op == CALL) {
if (instructions_[i].op == OP || instructions_[i].op == CALL || instructions_[i].op == OPN) {
out << " # " << *instructions_source_[i];
} else {
out << "\n";
@ -890,7 +895,9 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
++af.pc;
break;
case OPN:
AT_ERROR("OPN is currently supported in mobile mode only.");
stack.push_back(inst.N);
af.operators[inst.X](stack);
++af.pc;
break;
case LOAD:
stack.emplace_back(reg(inst.X));

View File

@ -69,14 +69,16 @@ struct TORCH_API Operator {
c10Handle_(opHandle),
options_(c10Handle_->options()) {}
Operator(
FunctionSchema schema,
OperationCreator op_creator,
const std::string& schema,
int(*op)(Stack&),
c10::OperatorOptions options = c10::OperatorOptions())
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
op_creator_(std::move(op_creator)),
: schema_string_(schema),
op_(std::make_shared<Operation>(std::move(op))),
options_(std::move(options)) {}
Operator(
const std::string& schema,
OperationCreator op_creator,
@ -88,40 +90,13 @@ struct TORCH_API Operator {
// Helper constructor to register `op` to run
// run for _every_ IR Node where n.kind() == name, regardless of arguments.
// This is accomplished by marking the schema varargs and having no required
// arguments. This is used for things like prim::While or prim::If that can
// take a number of different valid input types and lengths.
// arguments.
Operator(
Symbol name,
OperationCreator op_creator,
c10::OperatorOptions options = c10::OperatorOptions())
: Operator(
varArgSchemaWithName(name),
std::move(op_creator),
std::move(options)) {}
Operator(
Symbol name,
Operation op,
c10::OperatorOptions options = c10::OperatorOptions())
: Operator(
varArgSchemaWithName(name),
std::move(op),
std::move(options)) {}
Operator(
FunctionSchema schema,
Operation op,
c10::OperatorOptions options = c10::OperatorOptions())
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
op_(std::make_shared<Operation>(std::move(op))),
options_(std::move(options)) {}
Operator(
const std::string& schema,
int(*op)(Stack&),
c10::OperatorOptions options = c10::OperatorOptions())
: schema_string_(schema),
op_(std::make_shared<Operation>(std::move(op))),
: schema_(std::make_shared<FunctionSchema>(varArgSchemaWithName(name))),
op_creator_(std::move(op_creator)),
options_(std::move(options)) {}
Operation getOperation(const Node* node = nullptr) const {
@ -159,7 +134,9 @@ struct TORCH_API Operator {
}
return options_.aliasAnalysis();
}
bool hasOperation() const {
return op_ != nullptr;
}
private:
static FunctionSchema varArgSchemaWithName(Symbol name) {
return FunctionSchema(

View File

@ -109,10 +109,9 @@ bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {
}
RegisterOperators mm_tree_reduction_reg({Operator(
prim::MMTreeReduce,
[](const Node* node) -> Operation {
size_t num_inputs = node->inputs().size();
return [num_inputs](Stack& stack) {
"prim::MMTreeReduce(...) -> Tensor",
[](Stack& stack) {
auto num_inputs = pop(stack).toInt();
std::vector<at::Tensor> inputs;
inputs.reserve(num_inputs);
for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
@ -160,7 +159,6 @@ RegisterOperators mm_tree_reduction_reg({Operator(
push(stack, std::move(acc));
}
return 0;
};
},
aliasAnalysisIsSpecialCase())});

View File

@ -755,10 +755,9 @@ RegisterOperators reg(
},
aliasAnalysisFromSchema()),
Operator(
prim::Print,
[](const Node* node) -> Operation {
size_t num_inputs = node->inputs().size();
return [num_inputs](Stack& stack) {
"prim::Print(...) -> ()",
[](Stack& stack) {
auto num_inputs = pop(stack).toInt();
std::stringstream ss;
bool first = true;
for (const IValue& i : last(stack, num_inputs)) {
@ -773,24 +772,21 @@ RegisterOperators reg(
TORCH_INTERNAL_ASSERT(handler);
handler(ss.str());
return 0;
};
},
aliasAnalysisSpecialCase()),
Operator(
prim::BroadcastSizes,
[](const Node* node) -> Operation {
size_t num_inputs = node->inputs().size();
return [num_inputs](Stack& stack) {
"prim::BroadcastSizes(...) -> int[]",
[](Stack& stack) {
auto num_inputs = pop(stack).toInt();
std::vector<int64_t> size;
size.reserve(8);
for (size_t i = 0; i < num_inputs; ++i) {
for (auto i = 0; i < num_inputs; ++i) {
size = at::infer_size(
size, peek(stack, i, num_inputs).toIntVector());
}
drop(stack, num_inputs);
push(stack, IValue(std::move(size)));
return 0;
};
},
aliasAnalysisSpecialCase()),
Operator(
@ -824,12 +820,7 @@ RegisterOperators reg(
},
aliasAnalysisSpecialCase()),
Operator(
FunctionSchema(
"aten::warn",
"",
{Argument("message", StringType::get()),
Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)},
{}),
"aten::warn(str message, int stacklevel=2) -> ()",
[](const Node* node) -> Operation {
auto range = node->sourceRange().source();
if (range->filename()) {
@ -873,7 +864,7 @@ RegisterOperators reg(
aliasAnalysisFromSchema()),
Operator(
c10::onnx::Reshape,
"onnx::Reshape(Tensor input, Tensor shape) -> Tensor",
[](Stack& stack) {
at::Tensor input, shape;
pop(stack, input, shape);
@ -885,7 +876,7 @@ RegisterOperators reg(
},
aliasAnalysisSpecialCase()),
Operator(
c10::onnx::Shape,
"onnx::Shape(Tensor t) -> Tensor",
[](Stack& stack) {
auto t = pop(stack).toTensor();
at::IntArrayRef sizes = t.sizes();
@ -901,9 +892,8 @@ RegisterOperators reg(
aliasAnalysisSpecialCase()),
Operator(
"prim::AutogradAnyNonZero(...) -> bool",
[](const Node* node) -> Operation {
size_t num_inputs = node->inputs().size();
return [num_inputs](Stack& stack) {
[](Stack& stack) {
auto num_inputs = pop(stack).toInt();
bool result = false;
for (const IValue& v : last(stack, num_inputs)) {
if (v.isTensor()) {
@ -927,11 +917,10 @@ RegisterOperators reg(
drop(stack, num_inputs);
stack.emplace_back(result);
return 0;
};
},
aliasAnalysisFromSchema()),
Operator(
prim::AutogradAdd,
"prim::AutogradAdd(Any a, Any b) -> Any",
[](Stack& stack) {
at::Tensor a, b;
pop(stack, a, b);
@ -1008,7 +997,8 @@ RegisterOperators reg(
},
aliasAnalysisSpecialCase()),
Operator(
prim::TupleIndex,
// note the compiler knows to type TupleIndex more accurately than it is listed here.
"prim::TupleIndex(Any tup, int i) -> Any",
[](Stack& stack) {
int64_t index = pop(stack).toInt();
auto tuple = pop(stack).toTuple();
@ -1225,7 +1215,7 @@ RegisterOperators reg(
"prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)",
noop,
aliasAnalysisFromSchema()),
Operator(prim::unchecked_cast, noop, aliasAnalysisSpecialCase()),
Operator("prim::unchecked_cast(t x) -> t", noop, aliasAnalysisSpecialCase()),
Operator(
prim::fork,
[](const Node* node) -> Operation {
@ -1258,7 +1248,7 @@ RegisterOperators reg(
},
aliasAnalysisSpecialCase()),
Operator(
prim::Uninitialized,
"prim::Uninitialized() -> Any",
[](Stack& stack) {
push(stack, IValue::uninitialized());
return 0;

View File

@ -264,12 +264,10 @@ RegisterOperators reg({
aliasAnalysisFromSchema()),
Operator(
"aten::format(str self, ...) -> str",
[](const Node* node) -> Operation {
size_t num_inputs = node->inputs().size();
return [num_inputs](Stack& stack) {
[](Stack& stack) {
size_t num_inputs = pop(stack).toInt();
formatFunc(num_inputs, stack);
return 0;
};
},
aliasAnalysisFromSchema()),

View File

@ -54,6 +54,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
{"bool", BoolType::get()},
{"None", NoneType::get()},
{"Capsule", CapsuleType::get()},
{"Any", at::AnyType::get()},
};
auto tok = L.cur();
if (!L.nextIf(TK_NONE)) {