[quant][graphmode][refactor] Factor out getInvokedMethod (#33649)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33649

Test Plan:
.

Imported from OSS

Differential Revision: D20123589

fbshipit-source-id: 0853d757434fb85c6d86666ff9fc991f8c4cb4bc
This commit is contained in:
Jerry Zhang 2020-02-27 23:43:39 -08:00 committed by Facebook Github Bot
parent 7f1112820a
commit f5f1e5e7f6

View File

@ -232,6 +232,13 @@ std::vector<std::string> getModuleAccessPath(Value* instance, Value* self) {
return path;
}
script::Module getInvokedModule(
script::Module& module, Node* n, Value* self) {
auto* instance = n->inputs()[0];
auto path = getModuleAccessPath(instance, self);
return findChildModule(module, path);
}
class ModuleCloneHelper {
public:
/** Clone according to module qconfig map, this is for handling the case
@ -622,21 +629,7 @@ ModuleMethodVector InsertObserversHelper::getInvokedMethods(
continue;
}
if (n->kind() == prim::CallMethod) {
// Record all method calls in the graph
auto module_instance = n->inputs()[0];
auto module_method_name = n->s(attr::name);
script::Module callee_module;
if (module_instance->node()->kind() == prim::GetAttr) {
auto child_module_name = module_instance->node()->s(attr::name);
callee_module = module.attr(child_module_name).toModule();
} else {
TORCH_INTERNAL_ASSERT(
module_instance == graph->inputs()[0],
"We only support call method either on %self"
"or child instance in insert_observers_pass right now");
callee_module = module;
}
invoked_methods.push_back({callee_module, module_method_name});
invoked_methods.push_back(std::make_pair(getInvokedModule(module, n, graph->inputs()[0]), n->s(attr::name)));
}
for (Block* subblock : n->blocks()) {