mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[quant][graphmode][refactor] Factor out getInvokedMethod (#33649)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33649 Test Plan: . Imported from OSS Differential Revision: D20123589 fbshipit-source-id: 0853d757434fb85c6d86666ff9fc991f8c4cb4bc
This commit is contained in:
parent
7f1112820a
commit
f5f1e5e7f6
|
|
@ -232,6 +232,13 @@ std::vector<std::string> getModuleAccessPath(Value* instance, Value* self) {
|
|||
return path;
|
||||
}
|
||||
|
||||
script::Module getInvokedModule(
|
||||
script::Module& module, Node* n, Value* self) {
|
||||
auto* instance = n->inputs()[0];
|
||||
auto path = getModuleAccessPath(instance, self);
|
||||
return findChildModule(module, path);
|
||||
}
|
||||
|
||||
class ModuleCloneHelper {
|
||||
public:
|
||||
/** Clone according to module qconfig map, this is for handling the case
|
||||
|
|
@ -622,21 +629,7 @@ ModuleMethodVector InsertObserversHelper::getInvokedMethods(
|
|||
continue;
|
||||
}
|
||||
if (n->kind() == prim::CallMethod) {
|
||||
// Record all method calls in the graph
|
||||
auto module_instance = n->inputs()[0];
|
||||
auto module_method_name = n->s(attr::name);
|
||||
script::Module callee_module;
|
||||
if (module_instance->node()->kind() == prim::GetAttr) {
|
||||
auto child_module_name = module_instance->node()->s(attr::name);
|
||||
callee_module = module.attr(child_module_name).toModule();
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
module_instance == graph->inputs()[0],
|
||||
"We only support call method either on %self"
|
||||
"or child instance in insert_observers_pass right now");
|
||||
callee_module = module;
|
||||
}
|
||||
invoked_methods.push_back({callee_module, module_method_name});
|
||||
invoked_methods.push_back(std::make_pair(getInvokedModule(module, n, graph->inputs()[0]), n->s(attr::name)));
|
||||
}
|
||||
|
||||
for (Block* subblock : n->blocks()) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user