mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[fx] Update symbolic_trace nn_module_stack (#114422)
Summary:
Fixed nn_module_stack dynamo produced by symbolic trace to align with the nn_module_stack metadata produced by dynamo. The key should be the module path, with the value being a unique name, and the type. Something like: `{'L__self___one_module': ("L['self'].one_module", <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>)}`
This was causing some tests to fail when using export + the old quantization flow (prepare_fx calls symbolic_trace).
Test Plan: D51534471 `buck2 run @//mode/dev-nosan //executorch/backends/xnnpack/test:test_xnnpack_quantized -- -r "test_xnnpack_leaky_relu"`
Differential Revision: D51539118
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114422
Approved by: https://github.com/JacobSzwejbka, https://github.com/jerryzh168
This commit is contained in:
parent
f505d76462
commit
a0be4b7ea7
|
|
@ -1752,8 +1752,8 @@ class TestFX(JitTestCase):
|
|||
gm = torch.fx.symbolic_trace(m)
|
||||
|
||||
mod_stack = {}
|
||||
expected_stack = [('sub_mod', type(m.sub_mod)),
|
||||
('sub_mod.conv_mod', type(m.sub_mod.conv_mod))]
|
||||
expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))),
|
||||
('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))]
|
||||
for node in gm.graph.nodes:
|
||||
mod_stack = node.meta.get('nn_module_stack', {})
|
||||
if mod_stack:
|
||||
|
|
|
|||
|
|
@ -478,7 +478,7 @@ class Tracer(TracerBase):
|
|||
with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
|
||||
# module_stack is an ordered dict so writing then deleting the
|
||||
# entry is equivalent to push/pop on a list
|
||||
self.module_stack[_scope.module_path] = _scope.module_type
|
||||
self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type)
|
||||
if not self.is_leaf_module(m, module_qualified_name):
|
||||
ret_val = forward(*args, **kwargs)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ class TracerBase:
|
|||
scope : Scope
|
||||
|
||||
# Records the module call stack
|
||||
module_stack: OrderedDict[str, str]
|
||||
module_stack: OrderedDict[str, Tuple[str, Any]]
|
||||
|
||||
# Mapping of node name to module scope
|
||||
node_name_to_scope: Dict[str, Tuple[str, type]]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user